1 //
2 // ShapeReduction.cpp
3 // MNN
4 //
5 // Created by MNN on 2019/01/10.
6 // Copyright © 2018, Alibaba Group Holding Limited
7 //
8
9 #include "shape/SizeComputer.hpp"
10 #include "core/Macro.h"
11 #include "core/TensorUtils.hpp"
12
13 namespace MNN {
/// Normalizes an axis index against a tensor rank `n`.
/// Negative axes count from the back (axis + n); non-negative axes pass through unchanged.
/// No range validation is done: values below -n stay negative, values >= n are returned as-is.
static int _getRealAxis(int axis, int n) {
    return axis >= 0 ? axis : n + axis;
}
20 class ReductionComputer : public SizeComputer {
21 public:
onComputeSize(const MNN::Op * op,const std::vector<Tensor * > & inputs,const std::vector<Tensor * > & outputs) const22 virtual bool onComputeSize(const MNN::Op* op, const std::vector<Tensor*>& inputs,
23 const std::vector<Tensor*>& outputs) const override {
24 MNN_ASSERT(1 == inputs.size() || 2 == inputs.size());
25 MNN_ASSERT(1 == outputs.size());
26
27 auto output = outputs[0];
28 TensorUtils::getDescribe(output)->dimensionFormat = TensorUtils::getDescribe(inputs[0])->dimensionFormat;
29 auto reduce = op->main_as_ReductionParam();
30 output->buffer().type = inputs[0]->buffer().type;
31 if (nullptr == reduce->dim() && inputs.size() == 1) {
32 output->buffer().dimensions = 0;
33 return true;
34 }
35 std::set<int> reduceDimSet;
36 if (nullptr != reduce->dim()) {
37 for (int i = 0; i < reduce->dim()->size(); ++i) {
38 reduceDimSet.insert(_getRealAxis(reduce->dim()->data()[i], inputs[0]->dimensions()));
39 }
40 } else {
41 auto input1 = inputs[1];
42 auto size = input1->elementSize();
43 auto dims = input1->host<int32_t>();
44 for (int i = 0; i < size; ++i) {
45 reduceDimSet.insert(_getRealAxis(dims[i], inputs[0]->dimensions()));
46 }
47 }
48
49 auto input = inputs[0];
50 const int inputDimensions = input->dimensions();
51 if (reduceDimSet.find(-1) != reduceDimSet.end()) {
52 // dim set have -1 which mean applying reduction on last dimension
53 reduceDimSet.erase(-1);
54 reduceDimSet.insert(inputDimensions - 1);
55 }
56
57 std::vector<int> newDims;
58 for (int i = 0; i < inputDimensions; ++i) {
59 if (reduceDimSet.find(i) == reduceDimSet.end()) {
60 newDims.push_back(input->length(i));
61 } else if (reduce->keepDims()) {
62 newDims.push_back(1);
63 }
64 }
65 output->buffer().dimensions = (int)newDims.size();
66 for (int i = 0; i < newDims.size(); ++i) {
67 output->buffer().dim[i].extent = newDims[i];
68 }
69 TensorUtils::getDescribe(outputs[0])->dimensionFormat = TensorUtils::getDescribe(inputs[0])->dimensionFormat;
70
71 return true;
72 }
73 };
74
75 REGISTER_SHAPE_INPUTS(ReductionComputer, OpType_Reduction, {1});
76 } // namespace MNN
77