1 //
2 //  ShapeReduction.cpp
3 //  MNN
4 //
5 //  Created by MNN on 2019/01/10.
6 //  Copyright © 2018, Alibaba Group Holding Limited
7 //
8 
#include <set>
#include <vector>

#include "shape/SizeComputer.hpp"
#include "core/Macro.h"
#include "core/TensorUtils.hpp"
12 
13 namespace MNN {
// Normalizes an axis index against a tensor of rank `n`.
// Non-negative axes are returned unchanged; negative axes count from the end
// (so -1 maps to n - 1). No range check is performed: an axis below -n stays
// negative and an axis >= n is returned as-is, so callers must tolerate both.
static int _getRealAxis(int axis, int n) {
    return axis >= 0 ? axis : n + axis;
}
20 class ReductionComputer : public SizeComputer {
21 public:
onComputeSize(const MNN::Op * op,const std::vector<Tensor * > & inputs,const std::vector<Tensor * > & outputs) const22     virtual bool onComputeSize(const MNN::Op* op, const std::vector<Tensor*>& inputs,
23                                const std::vector<Tensor*>& outputs) const override {
24         MNN_ASSERT(1 == inputs.size() || 2 == inputs.size());
25         MNN_ASSERT(1 == outputs.size());
26 
27         auto output                                       = outputs[0];
28         TensorUtils::getDescribe(output)->dimensionFormat = TensorUtils::getDescribe(inputs[0])->dimensionFormat;
29         auto reduce                                       = op->main_as_ReductionParam();
30         output->buffer().type = inputs[0]->buffer().type;
31         if (nullptr == reduce->dim() && inputs.size() == 1) {
32             output->buffer().dimensions = 0;
33             return true;
34         }
35         std::set<int> reduceDimSet;
36         if (nullptr != reduce->dim()) {
37             for (int i = 0; i < reduce->dim()->size(); ++i) {
38                 reduceDimSet.insert(_getRealAxis(reduce->dim()->data()[i], inputs[0]->dimensions()));
39             }
40         } else {
41             auto input1 = inputs[1];
42             auto size   = input1->elementSize();
43             auto dims   = input1->host<int32_t>();
44             for (int i = 0; i < size; ++i) {
45                 reduceDimSet.insert(_getRealAxis(dims[i], inputs[0]->dimensions()));
46             }
47         }
48 
49         auto input                = inputs[0];
50         const int inputDimensions = input->dimensions();
51         if (reduceDimSet.find(-1) != reduceDimSet.end()) {
52             // dim set have -1 which mean applying reduction on last dimension
53             reduceDimSet.erase(-1);
54             reduceDimSet.insert(inputDimensions - 1);
55         }
56 
57         std::vector<int> newDims;
58         for (int i = 0; i < inputDimensions; ++i) {
59             if (reduceDimSet.find(i) == reduceDimSet.end()) {
60                 newDims.push_back(input->length(i));
61             } else if (reduce->keepDims()) {
62                 newDims.push_back(1);
63             }
64         }
65         output->buffer().dimensions = (int)newDims.size();
66         for (int i = 0; i < newDims.size(); ++i) {
67             output->buffer().dim[i].extent = newDims[i];
68         }
69         TensorUtils::getDescribe(outputs[0])->dimensionFormat = TensorUtils::getDescribe(inputs[0])->dimensionFormat;
70 
71         return true;
72     }
73 };
74 
// Registers ReductionComputer as the shape computer for OpType_Reduction.
// NOTE(review): the {1} list presumably marks input index 1 (the axes tensor
// that onComputeSize reads via host<int32_t>()) as one whose contents are
// needed during shape inference — confirm against the REGISTER_SHAPE_INPUTS
// definition.
REGISTER_SHAPE_INPUTS(ReductionComputer, OpType_Reduction, {1});
76 } // namespace MNN
77