1 //
2 //  ConvertBinaryToElementwise.cpp
3 //  MNNConverter
4 //
5 //  Created by MNN on 2019/09/06.
6 //  Copyright © 2018, Alibaba Group Holding Limited
7 //
8 
9 #include "../PostTreatUtils.hpp"
10 using namespace MNN;
11 class ConvertBinaryToElementwise : public PostConverter {
12 public:
onExecute(std::unique_ptr<MNN::NetT> & net) const13     virtual bool onExecute(std::unique_ptr<MNN::NetT>& net) const override {
14         auto& mNet = net;
15         for (auto iter = mNet->oplists.begin(); iter != mNet->oplists.end(); iter++) {
16             auto op = iter->get();
17 
18             if (op->type != MNN::OpType_BinaryOp) {
19                 continue;
20             }
21 
22             auto param = op->main.AsBinaryOp();
23             if (param->opType != BinaryOpOperation_MUL && param->opType != BinaryOpOperation_ADD &&
24                 param->opType != BinaryOpOperation_SUB) {
25                 continue;
26             }
27             const int inputNum = op->inputIndexes.size();
28             DCHECK(inputNum == 2) << "BinaryOp should have two inputs";
29 
30             const int inputIndex0 = op->inputIndexes[0];
31             auto inputOp0         = PostTreatUtils::_findOpByOutputIndex(inputIndex0, mNet.get());
32             const int inputIndex1 = op->inputIndexes[1];
33             auto inputOp1         = PostTreatUtils::_findOpByOutputIndex(inputIndex1, mNet.get());
34             bool readyToChange = (inputOp0->type == MNN::OpType_Convolution || inputOp0->type == MNN::OpType_Eltwise) &&
35                                  (inputOp1->type == MNN::OpType_Convolution || inputOp1->type == MNN::OpType_Eltwise);
36 
37             if (readyToChange) {
38                 // convert binary op to elementwise op
39                 auto elementParam = new MNN::EltwiseT;
40                 switch (param->opType) {
41                     case BinaryOpOperation_MUL:
42                         elementParam->type = EltwiseType_PROD;
43                         break;
44                     case BinaryOpOperation_ADD:
45                         elementParam->type = EltwiseType_SUM;
46                         break;
47                     case BinaryOpOperation_SUB:
48                         elementParam->type = EltwiseType_SUB;
49                         break;
50                     default:
51                         break;
52                 }
53                 op->type = MNN::OpType_Eltwise;
54                 op->main.Reset();
55                 op->main.type  = OpParameter_Eltwise;
56                 op->main.value = elementParam;
57             }
58         }
59         return true;
60     }
61 };
62 static PostConverterRegister<ConvertBinaryToElementwise> __l("ConvertBinaryToElementwise");
63