1 // 2 // ConvertBinaryToElementwise.cpp 3 // MNNConverter 4 // 5 // Created by MNN on 2019/09/06. 6 // Copyright © 2018, Alibaba Group Holding Limited 7 // 8 9 #include "../PostTreatUtils.hpp" 10 using namespace MNN; 11 class ConvertBinaryToElementwise : public PostConverter { 12 public: onExecute(std::unique_ptr<MNN::NetT> & net) const13 virtual bool onExecute(std::unique_ptr<MNN::NetT>& net) const override { 14 auto& mNet = net; 15 for (auto iter = mNet->oplists.begin(); iter != mNet->oplists.end(); iter++) { 16 auto op = iter->get(); 17 18 if (op->type != MNN::OpType_BinaryOp) { 19 continue; 20 } 21 22 auto param = op->main.AsBinaryOp(); 23 if (param->opType != BinaryOpOperation_MUL && param->opType != BinaryOpOperation_ADD && 24 param->opType != BinaryOpOperation_SUB) { 25 continue; 26 } 27 const int inputNum = op->inputIndexes.size(); 28 DCHECK(inputNum == 2) << "BinaryOp should have two inputs"; 29 30 const int inputIndex0 = op->inputIndexes[0]; 31 auto inputOp0 = PostTreatUtils::_findOpByOutputIndex(inputIndex0, mNet.get()); 32 const int inputIndex1 = op->inputIndexes[1]; 33 auto inputOp1 = PostTreatUtils::_findOpByOutputIndex(inputIndex1, mNet.get()); 34 bool readyToChange = (inputOp0->type == MNN::OpType_Convolution || inputOp0->type == MNN::OpType_Eltwise) && 35 (inputOp1->type == MNN::OpType_Convolution || inputOp1->type == MNN::OpType_Eltwise); 36 37 if (readyToChange) { 38 // convert binary op to elementwise op 39 auto elementParam = new MNN::EltwiseT; 40 switch (param->opType) { 41 case BinaryOpOperation_MUL: 42 elementParam->type = EltwiseType_PROD; 43 break; 44 case BinaryOpOperation_ADD: 45 elementParam->type = EltwiseType_SUM; 46 break; 47 case BinaryOpOperation_SUB: 48 elementParam->type = EltwiseType_SUB; 49 break; 50 default: 51 break; 52 } 53 op->type = MNN::OpType_Eltwise; 54 op->main.Reset(); 55 op->main.type = OpParameter_Eltwise; 56 op->main.value = elementParam; 57 } 58 } 59 return true; 60 } 61 }; 62 static PostConverterRegister<ConvertBinaryToElementwise> __l("ConvertBinaryToElementwise"); 63