1 //
2 //  Transformer.cpp
3 //  MNN
4 //
5 //  Created by MNN on 2019/12/16.
6 //  Copyright © 2018, Alibaba Group Holding Limited
7 //
8 
9 #include "Transformer.hpp"
10 #include "OpConverter.hpp"
11 #include "MNN_generated.h"
12 using namespace MNN::Express;
13 namespace MNN {
14 namespace Train {
15 
16 class TurnTrainable : public Express::Optimizer {
17 public:
TurnTrainable(Transformer::TrainConfig config)18     TurnTrainable(Transformer::TrainConfig config) {
19         mConfig = std::move(config);
20     }
onMeasure(const std::vector<VARP> & outputs,std::shared_ptr<Parameters> parameters=nullptr)21     virtual Cost onMeasure(const std::vector<VARP>& outputs,
22                            std::shared_ptr<Parameters> parameters = nullptr) override {
23         return Cost();
24     }
onExecute(const std::vector<VARP> & outputs,std::shared_ptr<Parameters> p)25     virtual bool onExecute(const std::vector<VARP>& outputs, std::shared_ptr<Parameters> p) override {
26         auto exprs = Variable::getExecuteOrder(outputs);
27         {
28             // Turn convolution be trainable convolution
29             for (auto expr : exprs) {
30                 auto newExpr = OpConverter::convert(expr);
31                 if (newExpr.get() != expr.get()) {
32                     Expr::replace(expr, newExpr);
33                 }
34             }
35         }
36         exprs                = Variable::getExecuteOrder(outputs);
37         auto& variableLimits = mConfig.variableLimits;
38         // Collect Const Variable and turn to Trainable
39         for (auto v : exprs) {
40             if (v->get() == nullptr && VARP::INPUT != v->inputType()) {
41                 auto name = v->name();
42                 auto info = v->outputInfo(0);
43                 if (halide_type_float != info->type.code) {
44                     continue;
45                 }
46                 bool match = variableLimits.empty();
47                 for (auto limit : variableLimits) {
48                     if (name.find(limit) != std::string::npos) {
49                         match = true;
50                         break;
51                     }
52                 }
53                 auto va = Variable::create(v, 0);
54                 if (match) {
55                     MNN_PRINT("Add Variable: %s\n", name.c_str());
56                     va.fix(VARP::TRAINABLE);
57                 } else {
58                     va.fix(VARP::CONSTANT);
59                 }
60             }
61         }
62         return true;
63     }
64 
65 private:
66     Transformer::TrainConfig mConfig;
67 };
68 
turnModelToTrainable(TrainConfig config)69 std::shared_ptr<Express::Optimizer> Transformer::turnModelToTrainable(TrainConfig config) {
70     std::shared_ptr<Express::Optimizer> res;
71     res.reset(new TurnTrainable(std::move(config)));
72     return res;
73 }
74 
75 class InferOptimizer : public Express::Optimizer {
76 public:
InferOptimizer()77     InferOptimizer(){}
onMeasure(const std::vector<VARP> & outputs,std::shared_ptr<Parameters> parameters=nullptr)78     virtual Cost onMeasure(const std::vector<VARP>& outputs, std::shared_ptr<Parameters> parameters = nullptr) override {
79         Cost c;
80         return c;
81     };
82 
onExecute(const std::vector<VARP> & outputs,std::shared_ptr<Parameters> parameters=nullptr)83     virtual bool onExecute(const std::vector<VARP>& outputs, std::shared_ptr<Parameters> parameters = nullptr) override {
84         auto exprs = Variable::getExecuteOrder(outputs);
85         for (auto& iter : exprs) {
86             auto op = iter->get();
87             if (nullptr == op) {
88                 continue;
89             }
90             if (op->type() != OpType_ConvInt8 && op->type() != OpType_DepthwiseConvInt8) {
91                 continue;
92             }
93             auto inputExpr = iter->inputs()[0]->expr().first;
94             if (inputExpr->get() == nullptr) {
95                 continue;
96             }
97             if (inputExpr->get()->type() != OpType_FloatToInt8) {
98                 continue;
99             }
100             auto subInputExpr = inputExpr->inputs()[0]->expr().first;
101             if (subInputExpr->get() == nullptr) {
102                 continue;
103             }
104             if (subInputExpr->get()->type() != OpType_Int8ToFloat) {
105                 continue;
106             }
107             //MNN_PRINT("Find direct\n");
108             std::vector<VARP> newInputs = subInputExpr->inputs();
109             auto newExpr = Expr::create(iter->extra(), std::move(newInputs));
110             Expr::replace(iter, newExpr);
111         }
112         return true;
113     }
114 };
115 
turnModelToInfer()116 std::shared_ptr<Express::Optimizer> Transformer::turnModelToInfer() {
117     return std::shared_ptr<Optimizer>(new InferOptimizer);
118 }
119 } // namespace Train
120 } // namespace MNN
121