1 // 2 // ShapeMoments.cpp 3 // MNN 4 // 5 // Created by MNN on 2019/02/28. 6 // Copyright © 2018, Alibaba Group Holding Limited 7 // 8 9 #include "shape/SizeComputer.hpp" 10 11 namespace MNN { 12 class MomentsComputer : public SizeComputer { 13 public: onComputeSize(const MNN::Op * op,const std::vector<Tensor * > & inputs,const std::vector<Tensor * > & outputs) const14 virtual bool onComputeSize(const MNN::Op* op, const std::vector<Tensor*>& inputs, 15 const std::vector<Tensor*>& outputs) const override { 16 MNN_ASSERT(1 == inputs.size()); 17 MNN_ASSERT(2 == outputs.size()); 18 19 auto input = inputs[0]; 20 auto mean = outputs[0]; 21 auto variance = outputs[1]; 22 auto momentsParam = op->main_as_MomentsParam(); 23 mean->buffer().type = input->getType();; 24 variance->buffer().type = input->getType(); 25 if (nullptr == momentsParam->dim()) { 26 mean->buffer().dimensions = 0; 27 variance->buffer().dimensions = 0; 28 TensorUtils::getDescribe(mean)->dimensionFormat = MNN_DATA_FORMAT_NCHW; 29 TensorUtils::getDescribe(variance)->dimensionFormat = MNN_DATA_FORMAT_NCHW; 30 return true; 31 } 32 33 std::set<int> momentsDims; 34 for (int i = 0; i < momentsParam->dim()->size(); ++i) { 35 momentsDims.insert(momentsParam->dim()->data()[i]); 36 } 37 std::vector<int> outputShape; 38 for (int i = 0; i < input->dimensions(); ++i) { 39 if (momentsDims.find(i) == momentsDims.end()) { 40 outputShape.push_back(input->length(i)); 41 } else if (momentsParam->keepDims()) { 42 outputShape.push_back(1); 43 } 44 } 45 46 const auto outputDim = outputShape.size(); 47 mean->buffer().dimensions = static_cast<int>(outputDim); 48 variance->buffer().dimensions = static_cast<int>(outputDim); 49 for (int i = 0; i < outputDim; ++i) { 50 mean->setLength(i, outputShape[i]); 51 variance->setLength(i, outputShape[i]); 52 } 53 TensorUtils::getDescribe(mean)->dimensionFormat = MNN_DATA_FORMAT_NC4HW4; 54 TensorUtils::getDescribe(variance)->dimensionFormat = MNN_DATA_FORMAT_NC4HW4; 55 56 return true; 57 } 58 }; 59 60 REGISTER_SHAPE(MomentsComputer, OpType_Moments); 61 62 } // namespace MNN 63