1 //
2 // GeometryReshape.cpp
3 // MNN
4 //
5 // Created by MNN on 2020/04/03.
6 // Copyright © 2018, Alibaba Group Holding Limited
7 //
8
9 #include "ConvertUtils.hpp"
10 #include "geometry/GeometryComputer.hpp"
11 namespace MNN {
12 class GeometryReshape : public GeometryComputer {
13 public:
onCompute(const Op * op,const std::vector<Tensor * > & inputs,const std::vector<Tensor * > & outputs,Context & context,CommandBuffer & res) const14 virtual bool onCompute(const Op* op, const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs,
15 Context& context, CommandBuffer& res) const override {
16 auto input = inputs[0];
17 auto output = outputs[0];
18 auto inputDes = TensorUtils::getDescribe(input);
19 auto outputDes = TensorUtils::getDescribe(output);
20 if (TensorUtils::getDescribe(input)->dimensionFormat == MNN_DATA_FORMAT_NC4HW4) {
21 auto midFormat = op->main_as_Reshape()->dimType();
22 if (MNN_DATA_FORMAT_NHWC == midFormat) {
23 // Convert to NHWC, reshape, and then convert to NC4HW4
24 std::shared_ptr<Tensor> nhwc(new Tensor);
25 TensorUtils::setupTensorInfo(input, nhwc.get(), MNN_DATA_FORMAT_NHWC);
26 ConvertUtils::compute(input, nhwc.get(), res);
27 res.extras.emplace_back(nhwc);
28 std::shared_ptr<Tensor> nhwc2(new Tensor);
29 TensorUtils::setupTensorInfo(output, nhwc2.get(), MNN_DATA_FORMAT_NHWC);
30 res.extras.emplace_back(nhwc2);
31 {
32 auto inputSlice = TensorUtils::getDescribe(nhwc.get())->regions;
33 if (inputSlice.empty()) {
34 // Create Full Refence
35 Tensor::InsideDescribe::Region totalSlice = TensorUtils::makeFullSlice(nhwc.get());
36 inputSlice.emplace_back(std::move(totalSlice));
37 }
38 TensorUtils::getDescribe(nhwc2.get())->regions = std::move(inputSlice);
39 TensorUtils::getDescribe(nhwc2.get())->memoryType = Tensor::InsideDescribe::MEMORY_VIRTUAL;
40 }
41 ConvertUtils::compute(nhwc2.get(), output, res);
42 return true;
43 }
44 }
45 auto inputSlice = inputDes->regions;
46 if (inputSlice.empty()) {
47 // Create Full Refence
48 Tensor::InsideDescribe::Region totalSlice = TensorUtils::makeFullSlice(input);
49 inputSlice.emplace_back(std::move(totalSlice));
50 }
51 outputDes->regions = std::move(inputSlice);
52 outputDes->memoryType = Tensor::InsideDescribe::MEMORY_VIRTUAL;
53 return true;
54 }
55 };
56 class SingleGeometryComputer : public GeometryComputer {
57 public:
onCompute(const Op * op,const std::vector<Tensor * > & inputs,const std::vector<Tensor * > & outputs,Context & context,CommandBuffer & res) const58 virtual bool onCompute(const Op* op, const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs,
59 Context& context, CommandBuffer& res) const override {
60 auto input = inputs[0];
61 auto output = outputs[0];
62 auto inputDes = TensorUtils::getDescribe(input);
63 auto outputDes = TensorUtils::getDescribe(output);
64 auto inputSlice = inputDes->regions;
65 if (inputSlice.empty()) {
66 // Create Full Refence
67 Tensor::InsideDescribe::Region totalSlice = TensorUtils::makeFullSlice(input);
68 inputSlice.emplace_back(std::move(totalSlice));
69 }
70 outputDes->regions = std::move(inputSlice);
71 outputDes->memoryType = Tensor::InsideDescribe::MEMORY_VIRTUAL;
72 return true;
73 }
74 };
75
_create()76 static void _create() {
77 std::shared_ptr<GeometryComputer> comp(new GeometryReshape);
78 GeometryComputer::registerGeometryComputer(comp, {OpType_Reshape});
79 std::shared_ptr<GeometryComputer> _comp(new SingleGeometryComputer);
80 GeometryComputer::registerGeometryComputer(_comp, {OpType_Squeeze, OpType_Unsqueeze, OpType_ExpandDims, OpType_Flatten, OpType_QuantizedReshape});
81 }
82
83 REGISTER_GEOMETRY(GeometryReshape, _create);
84 }; // namespace MNN
85