1 // 2 // ShapeBinaryOp.cpp 3 // MNN 4 // 5 // Created by MNN on 2019/01/10. 6 // Copyright © 2018, Alibaba Group Holding Limited 7 // 8 9 #include "shape/SizeComputer.hpp" 10 #include "core/Macro.h" 11 #include <vector> 12 namespace MNN { 13 class BinaryOpComputer : public SizeComputer { 14 public: outputBool(int operation)15 static bool outputBool(int operation) { 16 if (operation == BinaryOpOperation_GREATER_EQUAL) { 17 return true; 18 } 19 if (operation == BinaryOpOperation_GREATER) { 20 return true; 21 } 22 if (operation == BinaryOpOperation_LESS) { 23 return true; 24 } 25 if (operation == BinaryOpOperation_LESS_EQUAL) { 26 return true; 27 } 28 if (operation == BinaryOpOperation_EQUAL) { 29 return true; 30 } 31 if (operation == BinaryOpOperation_NOTEQUAL) { 32 return true; 33 } 34 return false; 35 } onComputeSize(const Op * op,const std::vector<Tensor * > & inputs,const std::vector<Tensor * > & outputs) const36 virtual bool onComputeSize(const Op* op, const std::vector<Tensor*>& inputs, 37 const std::vector<Tensor*>& outputs) const override { 38 MNN_ASSERT(2 == inputs.size()); 39 MNN_ASSERT(1 == outputs.size()); 40 // set output type & format 41 auto input0 = inputs[0], input1 = inputs[1], output = outputs[0]; 42 auto &buffer = output->buffer(); 43 const auto opType = op->main_as_BinaryOp()->opType(); 44 if (outputBool(opType)) { 45 buffer.type = halide_type_of<int32_t>(); 46 } else { 47 buffer.type = input0->getType(); 48 } 49 if (input0->getType().code != input1->getType().code) { 50 MNN_PRINT("Error for binary op: input0's type != input1's type\n"); 51 return false; 52 } 53 54 if (input0->dimensions() < input1->dimensions()) { 55 auto temp = input0; 56 input0 = input1; 57 input1 = temp; 58 } 59 TensorUtils::getDescribe(output)->dimensionFormat = TensorUtils::getDescribe(input0)->dimensionFormat; 60 return SizeComputer::computeBroadCastDims(op, inputs, outputs); 61 } 62 }; 63 64 REGISTER_SHAPE(BinaryOpComputer, OpType_BinaryOp); 65 } // namespace MNN 66