//
//  Transformer.cpp
//  MNN
//
//  Created by MNN on 2019/12/16.
//  Copyright © 2018, Alibaba Group Holding Limited
//

#include "Transformer.hpp"
#include "OpConverter.hpp"
#include "MNN_generated.h"
using namespace MNN::Express;
namespace MNN {
namespace Train {

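// TurnTrainable rewrites a frozen inference graph for training: convolutions are converted through
// OpConverter, and float constants (optionally filtered by name) are marked as trainable parameters.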
class TurnTrainable : public Express::Optimizer {
public:
    TurnTrainable(Transformer::TrainConfig config) {
        mConfig = std::move(config);
    }
    virtual Cost onMeasure(const std::vector<VARP>& outputs,
                           std::shared_ptr<Parameters> parameters = nullptr) override {
        return Cost();
    }
    virtual bool onExecute(const std::vector<VARP>& outputs, std::shared_ptr<Parameters> p) override {
        auto exprs = Variable::getExecuteOrder(outputs);
        {
            // Turn each convolution into its trainable counterpart
            for (auto expr : exprs) {
                auto newExpr = OpConverter::convert(expr);
                if (newExpr.get() != expr.get()) {
                    Expr::replace(expr, newExpr);
                }
            }
        }
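        // The convolution rewrite above may have replaced expressions, so recompute the
        // execution order before scanning for constants.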
        exprs = Variable::getExecuteOrder(outputs);
        auto& variableLimits = mConfig.variableLimits;
        // Collect const variables and turn the matching ones trainable
        for (auto v : exprs) {
            if (v->get() == nullptr && VARP::INPUT != v->inputType()) {
                auto name = v->name();
                auto info = v->outputInfo(0);
                if (halide_type_float != info->type.code) {
                    continue;
                }
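                // An empty variableLimits list makes every float constant trainable; otherwise only
                // constants whose names contain one of the limit substrings are selected.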
                bool match = variableLimits.empty();
                for (auto limit : variableLimits) {
                    if (name.find(limit) != std::string::npos) {
                        match = true;
                        break;
                    }
                }
                auto va = Variable::create(v, 0);
                if (match) {
                    MNN_PRINT("Add Variable: %s\n", name.c_str());
                    va.fix(VARP::TRAINABLE);
                } else {
                    va.fix(VARP::CONSTANT);
                }
            }
        }
        return true;
    }

private:
    Transformer::TrainConfig mConfig;
};

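// A minimal usage sketch for turnModelToTrainable; the "conv" filter and the `outputs` vector are
// illustrative assumptions, with variableLimits assumed to hold name substrings as used above:
//   Transformer::TrainConfig config;
//   config.variableLimits = {"conv"};                 // only constants whose names contain "conv"
//   auto trainOpt = Transformer::turnModelToTrainable(std::move(config));
//   trainOpt->onExecute(outputs);                     // outputs: the model's output VARPs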
std::shared_ptr<Express::Optimizer> Transformer::turnModelToTrainable(TrainConfig config) {
    std::shared_ptr<Express::Optimizer> res;
    res.reset(new TurnTrainable(std::move(config)));
    return res;
}

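// InferOptimizer strips the redundant Int8ToFloat -> FloatToInt8 round trip in front of
// int8 convolutions so the quantized graph feeds them directly.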
class InferOptimizer : public Express::Optimizer {
public:
    InferOptimizer() {}
    virtual Cost onMeasure(const std::vector<VARP>& outputs, std::shared_ptr<Parameters> parameters = nullptr) override {
        Cost c;
        return c;
    }

    virtual bool onExecute(const std::vector<VARP>& outputs, std::shared_ptr<Parameters> parameters = nullptr) override {
        auto exprs = Variable::getExecuteOrder(outputs);
        for (auto& iter : exprs) {
            auto op = iter->get();
            if (nullptr == op) {
                continue;
            }
            if (op->type() != OpType_ConvInt8 && op->type() != OpType_DepthwiseConvInt8) {
                continue;
            }
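            // Only rewrite when the convolution's input comes from a FloatToInt8 fed by an
            // Int8ToFloat, i.e. a dequantize/requantize pair an int8 convolution does not need.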
            auto inputExpr = iter->inputs()[0]->expr().first;
            if (inputExpr->get() == nullptr) {
                continue;
            }
            if (inputExpr->get()->type() != OpType_FloatToInt8) {
                continue;
            }
            auto subInputExpr = inputExpr->inputs()[0]->expr().first;
            if (subInputExpr->get() == nullptr) {
                continue;
            }
            if (subInputExpr->get()->type() != OpType_Int8ToFloat) {
                continue;
            }
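            // Rebuild the convolution expression with the same op but the Int8ToFloat's inputs,
            // bypassing the redundant conversion pair.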
            // MNN_PRINT("Find direct\n");
            std::vector<VARP> newInputs = subInputExpr->inputs();
            auto newExpr = Expr::create(iter->extra(), std::move(newInputs));
            Expr::replace(iter, newExpr);
        }
        return true;
    }
};

std::shared_ptr<Express::Optimizer> Transformer::turnModelToInfer() {
    return std::shared_ptr<Optimizer>(new InferOptimizer);
}
} // namespace Train
} // namespace MNN