1 //
2 //  EliminateQuantAndDequant.cpp
3 //  MNNConverter
4 //
5 //  Created by MNN on 2020/07/09.
6 //  Copyright © 2018, Alibaba Group Holding Limited
7 //
8 
9 #include "../TemplateMerge.hpp"
10 #include "MNN/expr/ExprCreator.hpp"
11 #include "MNN_generated.h"
12 
13 namespace MNN {
14 namespace Express {
15 
/// Optimizer pattern that cancels back-to-back quantize/dequantize pairs.
/// It has no state; constructing an instance performs the registration.
struct EliminateQuantAndDequant {
    EliminateQuantAndDequant();
};
20 
EliminateQuantAndDequant()21 EliminateQuantAndDequant::EliminateQuantAndDequant() {
22     auto match = [this](EXPRP expr) -> bool {
23         if (!expr->get() || (expr->get()->type() != OpType_FloatToInt8 && expr->get()->type() != OpType_Int8ToFloat)) {
24             return false;
25         }
26         VARP input         = expr->inputs().at(0);
27         const Op* input_op = input->expr().first->get();
28         if (!input_op) {
29             return false;
30         }
31         if (expr->get()->type() == OpType_FloatToInt8) {
32             if (input_op->type() != OpType_Int8ToFloat) {
33                 return false;
34             }
35         }
36         if (expr->get()->type() == OpType_Int8ToFloat) {
37             if (input_op->type() != OpType_FloatToInt8) {
38                 return false;
39             }
40         }
41         return true;
42     };
43 
44     auto fold = [this](EXPRP expr) -> bool {
45         VARP input = expr->inputs().at(0);
46         input      = input->expr().first->inputs().at(0);
47 
48         auto* identity   = new MNN::ExtraT;
49         identity->type   = "Identity";
50         identity->engine = "Tensorflow";
51         std::unique_ptr<MNN::OpT> identity_op(new MNN::OpT);
52         identity_op->name       = expr->name();
53         identity_op->type       = OpType_Extra;
54         identity_op->main.type  = OpParameter_Extra;
55         identity_op->main.value = identity;
56 
57         EXPRP identity_expr = Expr::create(identity_op.get(), {input});
58         Expr::replace(expr, identity_expr);
59         return true /*modified*/;
60     };
61     TemplateMerge::getInstance("Merge").insertTemplate("EliminateQuantAndDequant", match, fold, PASS_PRIORITY_LOW);
62 }
63 
64 static EliminateQuantAndDequant g_eliminate_quant_dequant;
65 
66 } // namespace Express
67 } // namespace MNN
68