//
//  ConvertUtils.cpp
//  MNN
//
//  Created by MNN on 2020/04/03.
//  Copyright © 2018, Alibaba Group Holding Limited
//

#include "ConvertUtils.hpp"
#include "core/OpCommonUtils.hpp"
namespace MNN {
bool ConvertUtils::compute(Tensor* input, Tensor* output, CommandBuffer& res) {
    auto inputDes = TensorUtils::getDescribe(input);
    auto outputDes = TensorUtils::getDescribe(output);
    auto inputFormat = inputDes->dimensionFormat;
    auto outputFormat = outputDes->dimensionFormat;
    if (MNN_DATA_FORMAT_NC4HW4 == inputFormat) {
        inputFormat = MNN_DATA_FORMAT_NCHW;
    }
    if (MNN_DATA_FORMAT_NC4HW4 == outputFormat) {
        outputFormat = MNN_DATA_FORMAT_NCHW;
    }
    auto inputSlice = inputDes->regions;
    MNN_ASSERT(input->dimensions() >= 1);
    MNN_ASSERT(output->dimensions() == input->dimensions());
    if (inputSlice.empty()) {
        inputSlice.resize(1);
        // Create Full Reference
        inputSlice[0] = TensorUtils::makeFullSlice(input);
    }
    if (inputFormat == outputFormat || 2 == input->dimensions()) {
        // Same format (or a 2D tensor): NCHW <-> NC4HW4 needs no layout change, forward the regions
        outputDes->regions = std::move(inputSlice);
        outputDes->memoryType = Tensor::InsideDescribe::MEMORY_VIRTUAL;
        return true;
    }
    // NHWC <-> NC4HW4: turn NHWC into NCHW
    // TODO: for multiple input regions there may be a better way to compute the new slices
    MNN_ASSERT(4 == input->dimensions());
    auto inside = input->width() * input->height();
    auto axis = input->channel();
    auto outside = input->batch();
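    // Both layouts can be viewed as a 3D tensor: NCHW is [outside, axis, inside] and
    // NHWC is [outside, inside, axis]. The lambda below swaps the two inner source
    // dimensions of a region and rebuilds packed dst strides, so executing the region
    // copy performs the transpose between the two layouts.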
    auto swap = [](Tensor::InsideDescribe::Region& inp) {
        auto tempStride = inp.src.stride[2];
        inp.src.stride[2] = inp.src.stride[1];
        inp.src.stride[1] = tempStride;
        auto tempSize = inp.size[2];
        inp.size[2] = inp.size[1];
        inp.size[1] = tempSize;
        inp.dst.stride[2] = 1;
        inp.dst.stride[1] = inp.size[2];
    };
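    // Fast path: when there is a single region, try to reshape it in place to the
    // 3D layout; if that succeeds, the swapped region can be forwarded directly.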
    if (inputSlice.size() == 1) {
        auto& inp = inputSlice[0];
        bool canReshape = false;
        if (inputFormat == MNN_DATA_FORMAT_NCHW) {
            canReshape = TensorUtils::reshapeSlice(inp, outside, inside, axis);
        } else {
            canReshape = TensorUtils::reshapeSlice(inp, outside, axis, inside);
        }
        if (canReshape) {
            swap(inp);
            outputDes->regions = std::move(inputSlice);
            outputDes->memoryType = Tensor::InsideDescribe::MEMORY_VIRTUAL;
            return true;
        }
    }
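    // General path: describe the whole input as a single full slice and transpose it.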
    auto slice = TensorUtils::makeFullSlice(input);
    if (inputFormat == MNN_DATA_FORMAT_NCHW) {
        TensorUtils::reshapeSlice(slice, outside, inside, axis);
    } else {
        TensorUtils::reshapeSlice(slice, outside, axis, inside);
    }
    swap(slice);

    outputDes->regions = {slice};
    outputDes->memoryType = Tensor::InsideDescribe::MEMORY_VIRTUAL;

    return true;
}

void ConvertUtils::broadcastto(Tensor* input, Tensor* output) {
    auto inputDes = TensorUtils::getDescribe(input);
    auto outputDes = TensorUtils::getDescribe(output);
    outputDes->memoryType = Tensor::InsideDescribe::MEMORY_VIRTUAL;
    if (input->elementSize() == output->elementSize()) {
        // Just copy the tensor
        auto inputSlice = inputDes->regions;
        if (inputSlice.empty()) {
            // Create Full Reference
            Tensor::InsideDescribe::Region totalSlice = TensorUtils::makeFullSlice(input);
            inputSlice.emplace_back(std::move(totalSlice));
        }
        outputDes->regions = std::move(inputSlice);
        return;
    }
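    // Real broadcast: align the input shape to the output rank by padding with leading 1s,
    // then collapse runs of matching dimensions so that only the broadcast axes stay separate.
    // For example, broadcasting a [3, 1, 5] input to [3, 4, 5] keeps the separated shapes
    // [3, 1, 5] / [3, 4, 5] and (assuming row-major strides from computeStride) produces a
    // single region of size {3, 4, 5} with src strides {5, 0, 1} and dst strides {20, 5, 1}.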
    int32_t inputShape[MNN_MAX_TENSOR_DIM];
    auto outputDim = output->dimensions();
    for (int i = 0; i < outputDim; ++i) {
        inputShape[i] = 1;
    }
    int offset = outputDim - input->dimensions();
    for (int i = 0; i < input->dimensions(); ++i) {
        inputShape[i + offset] = input->length(i);
    }
    // Collapse matching dimensions; strides are computed from the separated shapes below
    int sepInputShapeSize = 0;
    int sepOutputShapeSize = 0;
    int sepInputShape[MNN_MAX_TENSOR_DIM];
    int sepOutputShape[MNN_MAX_TENSOR_DIM];
    int currentInput = 1;
    int currentOutput = 1;
    for (int i = 0; i < outputDim; ++i) {
        if (inputShape[i] != output->length(i)) {
            if (1 < currentOutput) {
                sepInputShape[sepInputShapeSize++] = currentInput;
                sepOutputShape[sepOutputShapeSize++] = currentOutput;
            }
            sepInputShape[sepInputShapeSize++] = inputShape[i];
            sepOutputShape[sepOutputShapeSize++] = output->length(i);
            currentInput = 1;
            currentOutput = 1;
        } else {
            currentInput *= inputShape[i];
            currentOutput *= output->length(i);
        }
    }
    if (currentOutput != 1 || currentInput != 1) {
        sepInputShape[sepInputShapeSize++] = currentInput;
        sepOutputShape[sepOutputShapeSize++] = currentOutput;
    }
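    // Strides over the separated shapes; broadcast axes (input length 1) get a source
    // stride of 0 so the same input element is reused across the whole output axis.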
    int seperateOutputStrides[MNN_MAX_TENSOR_DIM];
    int seperateInputStrides[MNN_MAX_TENSOR_DIM];
    OpCommonUtils::computeStride(seperateOutputStrides, sepOutputShape, sepOutputShapeSize);
    OpCommonUtils::computeStride(seperateInputStrides, sepInputShape, sepInputShapeSize);
    for (int i = 0; i < sepInputShapeSize; ++i) {
        if (1 == sepInputShape[i]) {
            seperateInputStrides[i] = 0;
        }
    }

    // Split region by size, use stride to determine src and dst mapping
    int remainDimSize = sepInputShapeSize > 3 ? (int)sepInputShapeSize - 3 : 0;
    std::vector<int> remainStride(remainDimSize + 1);
    int remainSize = OpCommonUtils::computeStride(remainStride.data(), sepOutputShape, remainDimSize);
    outputDes->regions.resize(remainSize);
    std::vector<int> cords(remainDimSize + 1);
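    // Each region covers the innermost (up to) three separated dimensions; any remaining
    // outer dimensions are enumerated below, with the unraveled index of each region
    // turned into per-region src/dst offsets.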
    for (int index = 0; index < remainSize; ++index) {
        OpCommonUtils::unravelIndexHelper(cords, remainStride, remainDimSize, index);
        auto& reg = outputDes->regions[index];
        for (int i = 0; i < remainDimSize; ++i) {
            reg.src.offset += (cords[i] * seperateInputStrides[i]);
            reg.dst.offset += (cords[i] * seperateOutputStrides[i]);
        }
        reg.origin = input;
        for (int i = 0; i < 3; ++i) {
            auto match = (int)sepOutputShapeSize - i - 1;
            if (match < 0) {
                continue;
            }
            reg.size[3 - i - 1] = sepOutputShape[match];
            reg.src.stride[3 - i - 1] = seperateInputStrides[match];
            reg.dst.stride[3 - i - 1] = seperateOutputStrides[match];
        }
    }
}

} // namespace MNN