1 //
2 //  GeometryReshape.cpp
3 //  MNN
4 //
5 //  Created by MNN on 2020/04/03.
6 //  Copyright © 2018, Alibaba Group Holding Limited
7 //
8 
9 #include "ConvertUtils.hpp"
10 #include "geometry/GeometryComputer.hpp"
11 namespace MNN {
12 class GeometryReshape : public GeometryComputer {
13 public:
onCompute(const Op * op,const std::vector<Tensor * > & inputs,const std::vector<Tensor * > & outputs,Context & context,CommandBuffer & res) const14     virtual bool onCompute(const Op* op, const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs,
15                            Context& context, CommandBuffer& res) const override {
16         auto input     = inputs[0];
17         auto output    = outputs[0];
18         auto inputDes  = TensorUtils::getDescribe(input);
19         auto outputDes = TensorUtils::getDescribe(output);
20         if (TensorUtils::getDescribe(input)->dimensionFormat == MNN_DATA_FORMAT_NC4HW4) {
21             auto midFormat = op->main_as_Reshape()->dimType();
22             if (MNN_DATA_FORMAT_NHWC == midFormat) {
23                 // Convert to NHWC, reshape, and then convert to NC4HW4
24                 std::shared_ptr<Tensor> nhwc(new Tensor);
25                 TensorUtils::setupTensorInfo(input, nhwc.get(), MNN_DATA_FORMAT_NHWC);
26                 ConvertUtils::compute(input, nhwc.get(), res);
27                 res.extras.emplace_back(nhwc);
28                 std::shared_ptr<Tensor> nhwc2(new Tensor);
29                 TensorUtils::setupTensorInfo(output, nhwc2.get(), MNN_DATA_FORMAT_NHWC);
30                 res.extras.emplace_back(nhwc2);
31                 {
32                     auto inputSlice = TensorUtils::getDescribe(nhwc.get())->regions;
33                     if (inputSlice.empty()) {
34                         // Create Full Refence
35                         Tensor::InsideDescribe::Region totalSlice = TensorUtils::makeFullSlice(nhwc.get());
36                         inputSlice.emplace_back(std::move(totalSlice));
37                     }
38                     TensorUtils::getDescribe(nhwc2.get())->regions    = std::move(inputSlice);
39                     TensorUtils::getDescribe(nhwc2.get())->memoryType = Tensor::InsideDescribe::MEMORY_VIRTUAL;
40                 }
41                 ConvertUtils::compute(nhwc2.get(), output, res);
42                 return true;
43             }
44         }
45         auto inputSlice = inputDes->regions;
46         if (inputSlice.empty()) {
47             // Create Full Refence
48             Tensor::InsideDescribe::Region totalSlice = TensorUtils::makeFullSlice(input);
49             inputSlice.emplace_back(std::move(totalSlice));
50         }
51         outputDes->regions    = std::move(inputSlice);
52         outputDes->memoryType = Tensor::InsideDescribe::MEMORY_VIRTUAL;
53         return true;
54     }
55 };
56 class SingleGeometryComputer : public GeometryComputer {
57 public:
onCompute(const Op * op,const std::vector<Tensor * > & inputs,const std::vector<Tensor * > & outputs,Context & context,CommandBuffer & res) const58     virtual bool onCompute(const Op* op, const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs,
59                            Context& context, CommandBuffer& res) const override {
60         auto input      = inputs[0];
61         auto output     = outputs[0];
62         auto inputDes   = TensorUtils::getDescribe(input);
63         auto outputDes  = TensorUtils::getDescribe(output);
64         auto inputSlice = inputDes->regions;
65         if (inputSlice.empty()) {
66             // Create Full Refence
67             Tensor::InsideDescribe::Region totalSlice = TensorUtils::makeFullSlice(input);
68             inputSlice.emplace_back(std::move(totalSlice));
69         }
70         outputDes->regions    = std::move(inputSlice);
71         outputDes->memoryType = Tensor::InsideDescribe::MEMORY_VIRTUAL;
72         return true;
73     }
74 };
75 
_create()76 static void _create() {
77     std::shared_ptr<GeometryComputer> comp(new GeometryReshape);
78     GeometryComputer::registerGeometryComputer(comp, {OpType_Reshape});
79     std::shared_ptr<GeometryComputer> _comp(new SingleGeometryComputer);
80     GeometryComputer::registerGeometryComputer(_comp, {OpType_Squeeze, OpType_Unsqueeze, OpType_ExpandDims, OpType_Flatten, OpType_QuantizedReshape});
81 }
82 
83 REGISTER_GEOMETRY(GeometryReshape, _create);
84 }; // namespace MNN
85