1 // 2 // RemoveInverseTensorConverter.cpp 3 // MNNConverter 4 // 5 // Created by MNN on 2020/07/09. 6 // Copyright © 2018, Alibaba Group Holding Limited 7 // 8 9 #include "../TemplateMerge.hpp" 10 #include "MNN/expr/ExprCreator.hpp" 11 #include "MNN_generated.h" 12 #include "MergeHelpers.hpp" 13 14 namespace MNN { 15 namespace Express { 16 17 class RemoveInverseTensorConverter { 18 public: 19 RemoveInverseTensorConverter(); 20 }; 21 RemoveInverseTensorConverter()22RemoveInverseTensorConverter::RemoveInverseTensorConverter() { 23 auto match = [this](EXPRP expr) -> bool { 24 if (!expr->get() || expr->get()->type() != OpType_ConvertTensor) { 25 return false; 26 } 27 VARP input = expr->inputs().at(0); 28 EXPRP input_expr = input->expr().first; 29 if (!input_expr->get() || input_expr->get()->type() != OpType_ConvertTensor) { 30 return false; 31 } 32 33 const auto* convert1_params = input_expr->get()->main_as_TensorConvertInfo(); 34 const auto* convert2_params = expr->get()->main_as_TensorConvertInfo(); 35 if (convert1_params->source() != convert2_params->dest()) { 36 return false; 37 } 38 39 return true; 40 }; 41 42 auto fold = [this](EXPRP expr) -> bool { 43 VARP input = expr->inputs().at(0); 44 auto input_expr = input->expr().first; 45 46 const auto* convert1_params = input_expr->get()->main_as_TensorConvertInfo(); 47 const auto* convert2_params = expr->get()->main_as_TensorConvertInfo(); 48 EXPRP new_expr; 49 50 auto* identity = new MNN::ExtraT; 51 identity->type = "Identity"; 52 identity->engine = "Tensorflow"; 53 std::unique_ptr<MNN::OpT> identity_op(new MNN::OpT); 54 identity_op->name = expr->name(); 55 identity_op->type = OpType_Extra; 56 identity_op->main.type = OpParameter_Extra; 57 identity_op->main.value = identity; 58 59 VARP x = input_expr->inputs().at(0); 60 new_expr = Expr::create(identity_op.get(), {x}); 61 62 Expr::replace(expr, new_expr); 63 return true /*modified*/; 64 }; 65 TemplateMerge::getInstance("Merge").insertTemplate("RemoveInverseTensorConverter", match, fold, PASS_PRIORITY_LOW); 66 } 67 68 static RemoveInverseTensorConverter g_remove_inverse_tensor_convert; 69 70 } // namespace Express 71 } // namespace MNN 72