1 //
2 //  ConvertUtils.cpp
3 //  MNN
4 //
5 //  Created by MNN on 2020/04/03.
6 //  Copyright © 2018, Alibaba Group Holding Limited
7 //
8 
9 #include "ConvertUtils.hpp"
10 #include "core/OpCommonUtils.hpp"
11 namespace MNN {
compute(Tensor * input,Tensor * output,CommandBuffer & res)12 bool ConvertUtils::compute(Tensor* input, Tensor* output, CommandBuffer& res) {
13     auto inputDes     = TensorUtils::getDescribe(input);
14     auto outputDes    = TensorUtils::getDescribe(output);
15     auto inputFormat  = inputDes->dimensionFormat;
16     auto outputFormat = outputDes->dimensionFormat;
17     if (MNN_DATA_FORMAT_NC4HW4 == inputFormat) {
18         inputFormat = MNN_DATA_FORMAT_NCHW;
19     }
20     if (MNN_DATA_FORMAT_NC4HW4 == outputFormat) {
21         outputFormat = MNN_DATA_FORMAT_NCHW;
22     }
23     auto inputSlice = inputDes->regions;
24     MNN_ASSERT(input->dimensions() >= 1);
25     MNN_ASSERT(output->dimensions() == input->dimensions());
26     if (inputSlice.empty()) {
27         inputSlice.resize(1);
28         // Create Full Refence
29         inputSlice[0] = TensorUtils::makeFullSlice(input);
30     }
31     if (inputFormat == outputFormat || 2 == input->dimensions()) {
32         // No need for treat for NCWH <-> NC4HW4
33         outputDes->regions    = std::move(inputSlice);
34         outputDes->memoryType = Tensor::InsideDescribe::MEMORY_VIRTUAL;
35         return true;
36     }
37     // NHWC <-> NC4HW4: Turn NHWC to NCHW
38     // TODO for multi input can find better way to compute new slice
39     MNN_ASSERT(4 == input->dimensions());
40     auto inside  = input->width() * input->height();
41     auto axis    = input->channel();
42     auto outside = input->batch();
43     auto swap    = [](Tensor::InsideDescribe::Region& inp) {
44         auto tempStride   = inp.src.stride[2];
45         inp.src.stride[2] = inp.src.stride[1];
46         inp.src.stride[1] = tempStride;
47         auto tempSize     = inp.size[2];
48         inp.size[2]       = inp.size[1];
49         inp.size[1]       = tempSize;
50         inp.dst.stride[2] = 1;
51         inp.dst.stride[1] = inp.size[2];
52     };
53     if (inputSlice.size() == 1) {
54         auto& inp       = inputSlice[0];
55         bool canReshape = false;
56         if (inputFormat == MNN_DATA_FORMAT_NCHW) {
57             canReshape = TensorUtils::reshapeSlice(inp, outside, inside, axis);
58         } else {
59             canReshape = TensorUtils::reshapeSlice(inp, outside, axis, inside);
60         }
61         if (canReshape) {
62             swap(inp);
63             outputDes->regions    = std::move(inputSlice);
64             outputDes->memoryType = Tensor::InsideDescribe::MEMORY_VIRTUAL;
65             return true;
66         }
67     }
68     auto slice = TensorUtils::makeFullSlice(input);
69     if (inputFormat == MNN_DATA_FORMAT_NCHW) {
70         TensorUtils::reshapeSlice(slice, outside, inside, axis);
71     } else {
72         TensorUtils::reshapeSlice(slice, outside, axis, inside);
73     }
74     swap(slice);
75 
76     outputDes->regions    = {slice};
77     outputDes->memoryType = Tensor::InsideDescribe::MEMORY_VIRTUAL;
78 
79     return true;
80 }
81 
broadcastto(Tensor * input,Tensor * output)82 void ConvertUtils::broadcastto(Tensor* input, Tensor* output) {
83     auto inputDes         = TensorUtils::getDescribe(input);
84     auto outputDes        = TensorUtils::getDescribe(output);
85     outputDes->memoryType = Tensor::InsideDescribe::MEMORY_VIRTUAL;
86     if (input->elementSize() == output->elementSize()) {
87         // Just Copy Tensor
88         auto inputSlice = inputDes->regions;
89         if (inputSlice.empty()) {
90             // Create Full Refence
91             Tensor::InsideDescribe::Region totalSlice = TensorUtils::makeFullSlice(input);
92             inputSlice.emplace_back(std::move(totalSlice));
93         }
94         outputDes->regions = std::move(inputSlice);
95         return;
96     }
97     int32_t inputShape[MNN_MAX_TENSOR_DIM];
98     auto outputDim = output->dimensions();
99     for (int i=0; i<outputDim; ++i) {
100         inputShape[i] = 1;
101     }
102     int offset = outputDim - input->dimensions();
103     for (int i = 0; i < input->dimensions(); ++i) {
104         inputShape[i + offset] = input->length(i);
105     }
106     // Compute Strides
107     int sepInputShapeSize = 0;
108     int sepOutputShapeSize = 0;
109     int sepInputShape[MNN_MAX_TENSOR_DIM];
110     int sepOutputShape[MNN_MAX_TENSOR_DIM];
111     int currentInput  = 1;
112     int currentOutput = 1;
113     for (int i = 0; i < outputDim; ++i) {
114         if (inputShape[i] != output->length(i)) {
115             if (1 < currentOutput) {
116                 sepInputShape[sepInputShapeSize++] = currentInput;
117                 sepOutputShape[sepOutputShapeSize++] = currentOutput;
118             }
119             sepInputShape[sepInputShapeSize++] = (inputShape[i]);
120             sepOutputShape[sepOutputShapeSize++] = (output->length(i));
121             currentInput  = 1;
122             currentOutput = 1;
123         } else {
124             currentInput *= inputShape[i];
125             currentOutput *= output->length(i);
126         }
127     }
128     if (currentOutput != 1 || currentInput != 1) {
129         sepInputShape[sepInputShapeSize++] = (currentInput);
130         sepOutputShape[sepOutputShapeSize++] = (currentOutput);
131     }
132     int seperateOutputStrides[MNN_MAX_TENSOR_DIM];
133     int seperateInputStrides[MNN_MAX_TENSOR_DIM];
134     OpCommonUtils::computeStride(seperateOutputStrides, sepOutputShape, sepOutputShapeSize);
135     OpCommonUtils::computeStride(seperateInputStrides, sepInputShape, sepInputShapeSize);
136     for (int i = 0; i < sepInputShapeSize; ++i) {
137         if (1 == sepInputShape[i]) {
138             seperateInputStrides[i] = 0;
139         }
140     }
141 
142     // Split region by size, use stride to determine src and dst mapping
143     int remainDimSize = sepInputShapeSize > 3 ? (int)sepInputShapeSize - 3 : 0;
144     std::vector<int> remainStride(remainDimSize + 1);
145     int remainSize = OpCommonUtils::computeStride(remainStride.data(), sepOutputShape, remainDimSize);
146     outputDes->regions.resize(remainSize);
147     std::vector<int> cords(remainDimSize + 1);
148     for (int index = 0; index < remainSize; ++index) {
149         OpCommonUtils::unravelIndexHelper(cords, remainStride, remainDimSize, index);
150         auto& reg = outputDes->regions[index];
151         for (int i = 0; i < remainDimSize; ++i) {
152             reg.src.offset += (cords[i] * seperateInputStrides[i]);
153             reg.dst.offset += (cords[i] * seperateOutputStrides[i]);
154         }
155         reg.origin = input;
156         for (int i = 0; i < 3; ++i) {
157             auto match = (int)sepOutputShapeSize - i - 1;
158             if (match < 0) {
159                 continue;
160             }
161             reg.size[3 - i - 1]       = sepOutputShape[match];
162             reg.src.stride[3 - i - 1] = seperateInputStrides[match];
163             reg.dst.stride[3 - i - 1] = seperateOutputStrides[match];
164         }
165     }
166 }
167 
168 } // namespace MNN
169