1 // 2 // EliminateQuantAndDequant.cpp 3 // MNNConverter 4 // 5 // Created by MNN on 2020/07/09. 6 // Copyright © 2018, Alibaba Group Holding Limited 7 // 8 9 #include "../TemplateMerge.hpp" 10 #include "MNN/expr/ExprCreator.hpp" 11 #include "MNN_generated.h" 12 13 namespace MNN { 14 namespace Express { 15 16 class EliminateQuantAndDequant { 17 public: 18 EliminateQuantAndDequant(); 19 }; 20 EliminateQuantAndDequant()21EliminateQuantAndDequant::EliminateQuantAndDequant() { 22 auto match = [this](EXPRP expr) -> bool { 23 if (!expr->get() || (expr->get()->type() != OpType_FloatToInt8 && expr->get()->type() != OpType_Int8ToFloat)) { 24 return false; 25 } 26 VARP input = expr->inputs().at(0); 27 const Op* input_op = input->expr().first->get(); 28 if (!input_op) { 29 return false; 30 } 31 if (expr->get()->type() == OpType_FloatToInt8) { 32 if (input_op->type() != OpType_Int8ToFloat) { 33 return false; 34 } 35 } 36 if (expr->get()->type() == OpType_Int8ToFloat) { 37 if (input_op->type() != OpType_FloatToInt8) { 38 return false; 39 } 40 } 41 return true; 42 }; 43 44 auto fold = [this](EXPRP expr) -> bool { 45 VARP input = expr->inputs().at(0); 46 input = input->expr().first->inputs().at(0); 47 48 auto* identity = new MNN::ExtraT; 49 identity->type = "Identity"; 50 identity->engine = "Tensorflow"; 51 std::unique_ptr<MNN::OpT> identity_op(new MNN::OpT); 52 identity_op->name = expr->name(); 53 identity_op->type = OpType_Extra; 54 identity_op->main.type = OpParameter_Extra; 55 identity_op->main.value = identity; 56 57 EXPRP identity_expr = Expr::create(identity_op.get(), {input}); 58 Expr::replace(expr, identity_expr); 59 return true /*modified*/; 60 }; 61 TemplateMerge::getInstance("Merge").insertTemplate("EliminateQuantAndDequant", match, fold, PASS_PRIORITY_LOW); 62 } 63 64 static EliminateQuantAndDequant g_eliminate_quant_dequant; 65 66 } // namespace Express 67 } // namespace MNN 68