1 //
2 //  ShapeMoments.cpp
3 //  MNN
4 //
5 //  Created by MNN on 2019/02/28.
6 //  Copyright © 2018, Alibaba Group Holding Limited
7 //
8 
9 #include "shape/SizeComputer.hpp"
10 
11 namespace MNN {
12 class MomentsComputer : public SizeComputer {
13 public:
onComputeSize(const MNN::Op * op,const std::vector<Tensor * > & inputs,const std::vector<Tensor * > & outputs) const14     virtual bool onComputeSize(const MNN::Op* op, const std::vector<Tensor*>& inputs,
15                                const std::vector<Tensor*>& outputs) const override {
16         MNN_ASSERT(1 == inputs.size());
17         MNN_ASSERT(2 == outputs.size());
18 
19         auto input        = inputs[0];
20         auto mean         = outputs[0];
21         auto variance     = outputs[1];
22         auto momentsParam = op->main_as_MomentsParam();
23         mean->buffer().type = input->getType();;
24         variance->buffer().type = input->getType();
25         if (nullptr == momentsParam->dim()) {
26             mean->buffer().dimensions     = 0;
27             variance->buffer().dimensions = 0;
28             TensorUtils::getDescribe(mean)->dimensionFormat = MNN_DATA_FORMAT_NCHW;
29             TensorUtils::getDescribe(variance)->dimensionFormat = MNN_DATA_FORMAT_NCHW;
30             return true;
31         }
32 
33         std::set<int> momentsDims;
34         for (int i = 0; i < momentsParam->dim()->size(); ++i) {
35             momentsDims.insert(momentsParam->dim()->data()[i]);
36         }
37         std::vector<int> outputShape;
38         for (int i = 0; i < input->dimensions(); ++i) {
39             if (momentsDims.find(i) == momentsDims.end()) {
40                 outputShape.push_back(input->length(i));
41             } else if (momentsParam->keepDims()) {
42                 outputShape.push_back(1);
43             }
44         }
45 
46         const auto outputDim          = outputShape.size();
47         mean->buffer().dimensions     = static_cast<int>(outputDim);
48         variance->buffer().dimensions = static_cast<int>(outputDim);
49         for (int i = 0; i < outputDim; ++i) {
50             mean->setLength(i, outputShape[i]);
51             variance->setLength(i, outputShape[i]);
52         }
53         TensorUtils::getDescribe(mean)->dimensionFormat = MNN_DATA_FORMAT_NC4HW4;
54         TensorUtils::getDescribe(variance)->dimensionFormat = MNN_DATA_FORMAT_NC4HW4;
55 
56         return true;
57     }
58 };
59 
// Associates MomentsComputer with OpType_Moments via the REGISTER_SHAPE macro
// (presumably declared in shape/SizeComputer.hpp), so shape inference for Moments
// ops is expected to dispatch to MomentsComputer::onComputeSize — confirm against
// the macro definition.
REGISTER_SHAPE(MomentsComputer, OpType_Moments);
61 
62 } // namespace MNN
63