1 //
2 //  GeometryConvUtils.cpp
3 //  MNN
4 //
5 //  Created by MNN on 2020/07/15.
6 //  Copyright © 2018, Alibaba Group Holding Limited
7 //
8 
9 #include "GeometryConvUtils.hpp"
10 #include "ConvertUtils.hpp"
11 
12 #define ADD_PAD_VALUE(POS, OFFSET, NUM, STRIDE)               \
13     if (POS##Pad > 0) {                                       \
14         Tensor::InsideDescribe::Region region;                \
15         region.origin        = padVal;                        \
16         region.size[0]       = ic;                            \
17         region.size[1]       = NUM;                           \
18         region.size[2]       = POS##Pad;                      \
19         region.src.offset    = 0;                             \
20         region.dst.offset    = dstOffsetKx + (OFFSET);        \
21         region.src.stride[0] = 0;                             \
22         region.src.stride[1] = 0;                             \
23         region.src.stride[2] = 0;                             \
24         region.dst.stride[0] = dstStrideChannel;              \
25         region.dst.stride[1] = STRIDE;                        \
26         region.dst.stride[2] = 1;                             \
27         des->regions.emplace_back(std::move(region));         \
28     }
29 namespace MNN {
makeRelu6(flatbuffers::FlatBufferBuilder & builder,float minValue,float maxValue)30 flatbuffers::Offset<Op> GeometryConvUtils::makeRelu6(flatbuffers::FlatBufferBuilder& builder, float minValue, float maxValue) {
31     Relu6Builder relu6B(builder);
32     relu6B.add_maxValue(maxValue);
33     relu6B.add_minValue(minValue);
34     auto paOffset = relu6B.Finish().Union();
35     OpBuilder opB(builder);
36     opB.add_type(OpType_ReLU6);
37     opB.add_main_type(OpParameter_Relu6);
38     opB.add_main(paOffset);
39     return opB.Finish();
40 }
im2Col3d(Tensor * im2Col,Tensor * input,int ic,int kd,int kh,int kw,int batch,int od,int oh,int ow,int id,int ih,int iw,int sd,int sh,int sw,int dd,int dh,int dw,int pd,int ph,int pw,int srcKernelOffset)41 void GeometryConvUtils::im2Col3d(Tensor* im2Col, Tensor* input, int ic, int kd, int kh, int kw, int batch, int od, int oh, int ow,
42     int id, int ih, int iw, int sd, int sh, int sw, int dd, int dh, int dw, int pd, int ph, int pw, int srcKernelOffset) {
43     im2Col->buffer().type       = halide_type_of<float>();
44     im2Col->buffer().dimensions = 2;
45     im2Col->setLength(0, ic * kd * kh * kw);
46     im2Col->setLength(1, batch * od * oh * ow);
47     TensorUtils::setLinearLayout(im2Col);
48     auto des             = TensorUtils::getDescribe(im2Col);
49     des->memoryType      = Tensor::InsideDescribe::MEMORY_VIRTUAL;
50     des->dimensionFormat = MNN_DATA_FORMAT_NCHW;
51     des->regions.clear();
52     des->regions.reserve(batch * ic * kd * kh * kw);
53     for (int c = 0; c < ic; ++c) {
54         for (int n = 0; n < batch; ++n) {
55             auto dstOffset = (c * kd * kh * kw * batch + n) * od * oh * ow;
56             // auto dstOffset = (n * ic + c) * od * oh * ow * kd * kh * kw;
57             auto srcOffset = (n * ic + c) * id * ih * iw;
58             for (int kz = 0; kz < kd; ++kz) {
59                 auto startSz = kz * dd - pd;
60                 int startDz = 0;
61                 if (startSz < 0) {
62                     startDz = ((-startSz) + sd - 1) / sd;
63                     startSz = startSz + startDz * sd;
64                 }
65                 auto endDz = od - 1;
66                 auto endSz = endDz * sd + kz * dd - pd;
67                 if (endSz >= id) {
68                     endDz = endDz - (endSz - id + sd) / sd;
69                     endSz = endDz * sd + kz * dd - pd;
70                 }
71                 if (startDz > endDz || endDz < 0 || startSz >= id) {
72                     continue;
73                 }
74                 auto dstOffsetKz = dstOffset + kz * kw * kh * ow * oh * od * batch + startDz * oh *  ow;
75                 auto srcOffsetKz = srcOffset + startSz * ih * iw;
76                 for (int ky = 0; ky < kh; ++ky) {
77                     auto startSy = ky * dh - ph;
78                     int startDy  = 0;
79                     if (startSy < 0) {
80                         startDy = ((-startSy) + sh - 1) / sh;
81                         startSy = startSy + startDy * sh;
82                     }
83                     auto endDy = oh - 1;
84                     auto endSy = endDy * sh + ky * dh - ph;
85                     if (endSy >= ih) {
86                         endDy = endDy - (endSy - ih + sh) / sh;
87                         endSy = endDy * sh + ky * dh - ph;
88                     }
89                     if (startDy > endDy || endDy < 0 || startSy >= ih) {
90                         continue;
91                     }
92                     auto dstOffsetKy = dstOffsetKz + ky * kw * ow * oh * od * batch + startDy * ow;
93                     auto srcOffsetKy = srcOffsetKz + startSy * iw;
94                     for (int kx = 0; kx < kw; ++kx) {
95                         auto startSx = kx * dw - pw;
96                         int startDx  = 0;
97                         if (startSx < 0) {
98                             startDx = ((-startSx) + sw - 1) / sw;
99                             startSx = startSx + startDx * sw;
100                         }
101                         auto endDx = ow - 1;
102                         auto endSx = endDx * sw + kx * dw - pw;
103                         if (endSx >= iw) {
104                             endDx = endDx - (endSx - iw + sw) / sw;
105                             endSx = endDx * sw + kx * dw - pw;
106                         }
107                         if (startDx > endDx || endDx < 0 || startSx >= iw) {
108                             continue;
109                         }
110                         auto dstOffsetKx = dstOffsetKy + kx * od * oh * ow * batch + startDx;
111                         auto srcOffsetKx = srcOffsetKy + startSx + srcKernelOffset * (kx + ky * kw);
112                         Tensor::InsideDescribe::Region region;
113                         region.origin        = input;
114                         region.size[0]       = endDz - startDz + 1;
115                         region.size[1]       = endDy - startDy + 1;
116                         region.size[2]       = endDx - startDx + 1;
117                         region.src.offset    = srcOffsetKx;
118                         region.dst.offset    = dstOffsetKx;
119                         region.src.stride[0] = sd * ih * iw;
120                         region.dst.stride[0] = oh * ow;
121                         region.src.stride[1] = sh * iw;
122                         region.dst.stride[1] = ow;
123                         region.src.stride[2] = sw;
124                         region.dst.stride[2] = 1;
125                         des->regions.emplace_back(std::move(region));
126                     }
127                 }
128                 // MNN_ASSERT(des->regions.size() > 0);
129             }
130         }
131     }
132 }
im2Col(Tensor * im2Col,Tensor * input,int ic,int kh,int kw,int batch,int oh,int ow,int ih,int iw,int sh,int sw,int dh,int dw,std::pair<int,int> pads,int srcKernelOffset,Tensor * padVal)133 void GeometryConvUtils::im2Col(Tensor* im2Col, Tensor* input, int ic, int kh, int kw, int batch, int oh, int ow, int ih,
134                                int iw, int sh, int sw, int dh, int dw, std::pair<int, int> pads, int srcKernelOffset, Tensor* padVal) {
135     im2Col->buffer().type       = halide_type_of<float>();
136     im2Col->buffer().dimensions = 2;
137     im2Col->setLength(0, ic * kw * kh);
138     im2Col->setLength(1, batch * ow * oh);
139     TensorUtils::setLinearLayout(im2Col);
140     auto des             = TensorUtils::getDescribe(im2Col);
141     des->memoryType      = Tensor::InsideDescribe::MEMORY_VIRTUAL;
142     des->dimensionFormat = MNN_DATA_FORMAT_NCHW;
143     des->regions.clear();
144     if (padVal == nullptr) {
145         des->regions.reserve(batch * kw * kh);
146     }
147     int dstStrideChannel = batch * oh * ow * kh * kw;
148     int srcStrideChannel = iw * ih;
149     for (int n = 0; n < batch; ++n) {
150         auto dstOffset = ow * oh * n;
151         auto srcOffset = n * ic * iw * ih;
152         for (int ky = 0; ky < kh; ++ky) {
153             auto startSy = ky * dh - pads.second;
154             int startDy  = 0;
155             int upPad = 0, belowPad = 0;
156             if (startSy < 0) {
157                 startDy = ((-startSy) + sh - 1) / sh;
158                 startSy = startSy + startDy * sh;
159                 upPad = startDy * ow;
160             }
161             auto endDy = oh - 1;
162             auto endSy = endDy * sh + ky * dh - pads.second;
163             if (endSy >= ih) {
164                 endDy = endDy - (endSy - ih + sh) / sh;
165                 endSy = endDy * sh + ky * dh - pads.second;
166                 belowPad = (oh - endDy - 1) * ow;
167             }
168             if (startDy > endDy || endDy < 0 || startSy >= ih) {
169                 continue;
170             }
171             auto dstOffsetKy = dstOffset + ky * kw * ow * oh * batch + startDy * ow;
172             auto srcOffsetKy = srcOffset + startSy * iw;
173             for (int kx = 0; kx < kw; ++kx) {
174                 auto startSx = kx * dw - pads.first;
175                 int startDx  = 0;
176                 int leftPad = 0, rightPad = 0;
177                 if (startSx < 0) {
178                     startDx = ((-startSx) + sw - 1) / sw;
179                     startSx = startSx + startDx * sw;
180                     leftPad = startDx;
181                 }
182                 auto endDx = ow - 1;
183                 auto endSx = endDx * sw + kx * dw - pads.first;
184                 if (endSx >= iw) {
185                     endDx = endDx - (endSx - iw + sw) / sw;
186                     endSx = endDx * sw + kx * dw - pads.first;
187                     rightPad = ow - endDx - 1;
188                 }
189                 if (startDx > endDx || endDx < 0 || startSx >= iw) {
190                     continue;
191                 }
192                 auto dstOffsetKx = dstOffsetKy + kx * ow * oh * batch + startDx;
193                 auto srcOffsetKx = srcOffsetKy + startSx + srcKernelOffset * (kx + ky * kw);
194                 const int ohExcludePad = endDy - startDy + 1;
195                 const int owExcludePad = endDx - startDx + 1;
196                 // if given padVal, pad value will use padVa otherwise use zero
197                 if (padVal) {
198                     ADD_PAD_VALUE(up, -(startDx+upPad), 1, 0);
199                     ADD_PAD_VALUE(below, ohExcludePad * ow - startDx, 1, 0);
200                     ADD_PAD_VALUE(left, -leftPad, ohExcludePad, ow);
201                     ADD_PAD_VALUE(right, owExcludePad, ohExcludePad, ow);
202                 }
203                 Tensor::InsideDescribe::Region region;
204                 region.origin        = input;
205                 region.size[0]       = ic;
206                 region.size[1]       = ohExcludePad;
207                 region.size[2]       = owExcludePad;
208                 region.src.offset    = srcOffsetKx;
209                 region.dst.offset    = dstOffsetKx;
210                 region.src.stride[0] = srcStrideChannel;
211                 region.dst.stride[0] = dstStrideChannel;
212                 region.src.stride[1] = sh * iw;
213                 region.dst.stride[1] = ow;
214                 region.src.stride[2] = sw;
215                 region.dst.stride[2] = 1;
216                 des->regions.emplace_back(std::move(region));
217             }
218         }
219         // MNN_ASSERT(des->regions.size() > 0);
220     }
221 }
computeSingle(const Op * op,const std::vector<Tensor * > & inputs,const std::vector<Tensor * > & outputs,GeometryComputer::Context & context,CommandBuffer & res)222 bool GeometryConvUtils::computeSingle(const Op* op, const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs, GeometryComputer::Context& context, CommandBuffer& res) {
223     auto newOutputs   = outputs;
224     auto newInputs    = inputs;
225     auto originOutput = outputs[0];
226     auto output       = originOutput;
227     auto inputDes     = TensorUtils::getDescribe(newInputs[0]);
228     auto format       = inputDes->dimensionFormat;
229     if (MNN_DATA_FORMAT_NC4HW4 != format) {
230         std::shared_ptr<Tensor> newInput(new Tensor(newInputs[0], Tensor::CAFFE_C4, false));
231         ConvertUtils::compute(newInputs[0], newInput.get(), res);
232         newInputs[0] = newInput.get();
233         res.extras.emplace_back(std::move(newInput));
234         std::shared_ptr<Tensor> newOutput(new Tensor(originOutput, Tensor::CAFFE_C4, false));
235         output        = newOutput.get();
236         newOutputs[0] = output;
237         res.extras.emplace_back(newOutput);
238     }
239     Command cmd;
240     cmd.op      = op;
241     cmd.inputs  = std::move(newInputs);
242     cmd.outputs = std::move(newOutputs);
243     res.command.emplace_back(std::move(cmd));
244     if (originOutput != output) {
245         ConvertUtils::compute(output, originOutput, res);
246     }
247     return true;
248 }
249 }; // namespace MNN
250