1 //
2 //  RemoveInverseTensorConverter.cpp
3 //  MNNConverter
4 //
5 //  Created by MNN on 2020/07/09.
6 //  Copyright © 2018, Alibaba Group Holding Limited
7 //
8 
9 #include "../TemplateMerge.hpp"
10 #include "MNN/expr/ExprCreator.hpp"
11 #include "MNN_generated.h"
12 #include "MergeHelpers.hpp"
13 
14 namespace MNN {
15 namespace Express {
16 
// Stateless registrar for the "RemoveInverseTensorConverter" merge pass.
// It holds no data; all of its work happens in the out-of-line default
// constructor, which installs the pass's match/fold callbacks.
struct RemoveInverseTensorConverter {
    // Registers the pass (defined out of line in this file).
    RemoveInverseTensorConverter();
};
21 
RemoveInverseTensorConverter()22 RemoveInverseTensorConverter::RemoveInverseTensorConverter() {
23     auto match = [this](EXPRP expr) -> bool {
24         if (!expr->get() || expr->get()->type() != OpType_ConvertTensor) {
25             return false;
26         }
27         VARP input       = expr->inputs().at(0);
28         EXPRP input_expr = input->expr().first;
29         if (!input_expr->get() || input_expr->get()->type() != OpType_ConvertTensor) {
30             return false;
31         }
32 
33         const auto* convert1_params = input_expr->get()->main_as_TensorConvertInfo();
34         const auto* convert2_params = expr->get()->main_as_TensorConvertInfo();
35         if (convert1_params->source() != convert2_params->dest()) {
36             return false;
37         }
38 
39         return true;
40     };
41 
42     auto fold = [this](EXPRP expr) -> bool {
43         VARP input      = expr->inputs().at(0);
44         auto input_expr = input->expr().first;
45 
46         const auto* convert1_params = input_expr->get()->main_as_TensorConvertInfo();
47         const auto* convert2_params = expr->get()->main_as_TensorConvertInfo();
48         EXPRP new_expr;
49 
50         auto* identity   = new MNN::ExtraT;
51         identity->type   = "Identity";
52         identity->engine = "Tensorflow";
53         std::unique_ptr<MNN::OpT> identity_op(new MNN::OpT);
54         identity_op->name       = expr->name();
55         identity_op->type       = OpType_Extra;
56         identity_op->main.type  = OpParameter_Extra;
57         identity_op->main.value = identity;
58 
59         VARP x   = input_expr->inputs().at(0);
60         new_expr = Expr::create(identity_op.get(), {x});
61 
62         Expr::replace(expr, new_expr);
63         return true /*modified*/;
64     };
65     TemplateMerge::getInstance("Merge").insertTemplate("RemoveInverseTensorConverter", match, fold, PASS_PRIORITY_LOW);
66 }
67 
68 static RemoveInverseTensorConverter g_remove_inverse_tensor_convert;
69 
70 } // namespace Express
71 } // namespace MNN
72