1 //
2 // GeometryImageOp.cpp
3 // MNN
4 //
5 // Created by MNN on 2020/05/07.
6 // Copyright © 2018, Alibaba Group Holding Limited
7 //
8
9 #include "ConvertUtils.hpp"
10 #include "geometry/GeometryComputer.hpp"
11 #include "geometry/GeometryComputerUtils.hpp"
12 #include "shape/SizeComputer.hpp"
13 namespace MNN {
14
15
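/**
 Coordinate transformation modes: how a coordinate in the resized output (x_resized) maps back
 to a coordinate in the original input (x_original). In the formulas below,
 scale = length_resized / length_original, i.e. the reciprocal of the in / out ratios that
 _ConverterInterp stores in InterpInfo.

 if coordinate_transformation_mode is "half_pixel",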
18 x_original = (x_resized + 0.5) / scale - 0.5,
19
20 if coordinate_transformation_mode is "pytorch_half_pixel",
21 x_original = length_resized > 1 ? (x_resized + 0.5) / scale - 0.5 : 0,
22
23 if coordinate_transformation_mode is "align_corners",
24 x_original = x_resized * (length_original - 1) / (length_resized - 1),
25
26 if coordinate_transformation_mode is "asymmetric",
27 x_original = x_resized / scale,
28
29 if coordinate_transformation_mode is "tf_half_pixel_for_nn",
30 x_original = (x_resized + 0.5) / scale,
31
32 if coordinate_transformation_mode is "tf_crop_and_resize",
33 x_original = length_resized > 1 ? start_x * (length_original - 1) + x_resized * (end_x - start_x) * (length_original - 1) / (length_resized - 1) : 0.5 * (start_x + end_x) * (length_original - 1).
34 */
// Per-axis scale / offset pair that _ConverterInterp computes and makeInterp writes into the
// generated Interp op.
struct InterpInfo {
    // Scales default to 1.0f so that an InterpInfo which no code path fills in is still
    // serialized with well-defined values instead of indeterminate floats.
    float heightScale  = 1.0f;
    float widthScale   = 1.0f;
    // Offsets default to 0; only the half-pixel style modes assign them.
    float widthOffset  = 0.0f;
    float heightOffset = 0.0f;
};
_ConverterInterp(const Interp * resize,InterpInfo * dstInfo,int inW,int inH,int outW,int outH)41 static void _ConverterInterp(const Interp* resize, InterpInfo* dstInfo, int inW, int inH, int outW, int outH) {
42 switch (resize->ctm()) {
43 case CoordinateTransformationMode_NotSet:
44 {
45 // For compability, old model's nearest don't support halfpixels
46 if (resize->halfPixelCenters() && resize->resizeType() != 1) {
47 dstInfo->heightScale = (float)(inH) / (float)(outH);
48 dstInfo->widthScale = (float)(inW) / (float)(outW);
49 dstInfo->widthOffset = 0.5f * dstInfo->widthScale - 0.5f;
50 dstInfo->heightOffset = 0.5f * dstInfo->heightScale - 0.5f;
51 } else if (resize->alignCorners()) {
52 if (outH == 1) {
53 dstInfo->heightScale = 0.0f;
54 } else {
55 dstInfo->heightScale = (float)(inH - 1) / (float)(outH - 1);
56 }
57 if (outW == 1) {
58 dstInfo->widthScale = 0.0f;
59 } else {
60 dstInfo->widthScale = (float)(inW - 1) / (float)(outW - 1);
61 }
62 } else {
63 dstInfo->heightScale = (float)(inH) / (float)(outH);
64 dstInfo->widthScale = (float)(inW) / (float)(outW);
65 }
66 break;
67 }
68 case CoordinateTransformationMode_AlignCorners:
69 {
70 if (outH == 1) {
71 dstInfo->heightScale = 0.0f;
72 } else {
73 dstInfo->heightScale = (float)(inH - 1) / (float)(outH - 1);
74 }
75 if (outW == 1) {
76 dstInfo->widthScale = 0.0f;
77 } else {
78 dstInfo->widthScale = (float)(inW - 1) / (float)(outW - 1);
79 }
80 break;
81 }
82 case CoordinateTransformationMode_HalfPixels:
83 {
84 dstInfo->heightScale = (float)(inH) / (float)(outH);
85 dstInfo->widthScale = (float)(inW) / (float)(outW);
86 dstInfo->widthOffset = 0.5f * dstInfo->widthScale - 0.5f;
87 dstInfo->heightOffset = 0.5f * dstInfo->heightScale - 0.5f;
88 break;
89 }
90 case CoordinateTransformationMode_PytorchHalfPixels:
91 {
92 if (outH > 1) {
93 dstInfo->heightScale = (float)inH / (float)outH;
94 dstInfo->heightOffset = 0.5f * dstInfo->heightScale - 0.5f;
95 } else {
96 dstInfo->heightScale = 0.0f;
97 }
98 if (outW > 1) {
99 dstInfo->widthScale = (float)inW / (float)outW;
100 dstInfo->widthOffset = 0.5f * dstInfo->widthScale - 0.5f;
101 } else {
102 dstInfo->widthScale = 0.0f;
103 }
104 break;
105 }
106 case CoordinateTransformationMode_Asymmetric:
107 {
108 dstInfo->heightScale = (float)(inH) / (float)(outH);
109 dstInfo->widthScale = (float)(inW) / (float)(outW);
110 break;
111 }
112 case CoordinateTransformationMode_TensorflowHalfPixels:
113 {
114 dstInfo->heightScale = (float)(inH) / (float)(outH);
115 dstInfo->widthScale = (float)(inW) / (float)(outW);
116 dstInfo->widthOffset = 0.5f * dstInfo->widthScale;
117 dstInfo->heightOffset = 0.5f * dstInfo->heightScale;
118 break;
119 }
120 case CoordinateTransformationMode_TensorflowCropAndResize:
121 {
122 //FIXME: Not support now
123 MNN_ERROR("Don't support CoordinateTransformationMode_TensorflowCropAndResize currently\n");
124 break;
125 }
126 default:
127 break;
128 }
129 }
makeInterp(flatbuffers::FlatBufferBuilder & builder,const InterpInfo * info,int resizeType,const Op * op)130 static flatbuffers::Offset<Op> makeInterp(flatbuffers::FlatBufferBuilder& builder, const InterpInfo* info, int resizeType, const Op* op) {
131 flatbuffers::Offset<flatbuffers::String> temp;
132 if (nullptr != op->name()) {
133 temp = builder.CreateString(op->name()->str());
134 }
135 InterpBuilder intpB(builder);
136 intpB.add_resizeType(resizeType);
137 intpB.add_widthScale(info->widthScale);
138 intpB.add_heightScale(info->heightScale);
139 intpB.add_heightOffset(info->heightOffset);
140 intpB.add_widthOffset(info->widthOffset);
141 auto offsetInterp = intpB.Finish().Union();
142 OpBuilder opB(builder);
143 opB.add_type(OpType_Interp);
144 opB.add_main(offsetInterp);
145 opB.add_main_type(OpParameter_Interp);
146 if (nullptr != op->name()) {
147 opB.add_name(temp);
148 }
149 return opB.Finish();
150 }
151
152 class GeometryImageOp : public GeometryComputer {
153 public:
onCompute(const Op * op,const std::vector<Tensor * > & inputs,const std::vector<Tensor * > & outputs,Context & context,CommandBuffer & res) const154 virtual bool onCompute(const Op* op, const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs,
155 Context& context, CommandBuffer& res) const override {
156 auto newOutputs = outputs;
157 auto newInputs = inputs;
158 auto originOutput = outputs[0];
159 auto output = originOutput;
160 auto inputDes = TensorUtils::getDescribe(newInputs[0]);
161 auto format = inputDes->dimensionFormat;
162 if (MNN_DATA_FORMAT_NC4HW4 != format) {
163 std::shared_ptr<Tensor> newInput(new Tensor(newInputs[0], Tensor::CAFFE_C4, false));
164 ConvertUtils::compute(newInputs[0], newInput.get(), res);
165 newInputs[0] = newInput.get();
166 res.extras.emplace_back(std::move(newInput));
167 std::shared_ptr<Tensor> newOutput(new Tensor(originOutput, Tensor::CAFFE_C4, false));
168 output = newOutput.get();
169 newOutputs[0] = output;
170 res.extras.emplace_back(newOutput);
171 }
172 if (OpType_Resize == op->type()) {
173 // Turn resize to interp
174 InterpInfo info;
175 info.widthScale = (float)inputs[0]->width() / (float)outputs[0]->width();
176 info.heightScale = (float)inputs[0]->height() / (float)outputs[0]->height();
177 flatbuffers::FlatBufferBuilder builder;
178 builder.Finish(makeInterp(builder, &info, 2, op));
179 res.command.emplace_back(GeometryComputerUtils::makeCommand(builder, {newInputs[0]}, newOutputs));
180 }
181 else if (OpType_Interp == op->type()) {
182 // Compute cord transform for interp
183 auto resize = op->main_as_Interp();
184 auto inW = inputs[0]->width();
185 auto inH = inputs[0]->height();
186 auto outW = outputs[0]->width();
187 auto outH = outputs[0]->height();
188 InterpInfo info;
189 _ConverterInterp(resize, &info, inW, inH, outW, outH);
190 flatbuffers::FlatBufferBuilder builder;
191 builder.Finish(makeInterp(builder, &info, resize->resizeType(), op));
192 res.command.emplace_back(GeometryComputerUtils::makeCommand(builder, {newInputs[0]}, newOutputs));
193 } else {
194 Command cmd;
195 cmd.op = op;
196 cmd.inputs = std::move(newInputs);
197 cmd.outputs = std::move(newOutputs);
198 res.command.emplace_back(std::move(cmd));
199 }
200 if (originOutput != output) {
201 ConvertUtils::compute(output, originOutput, res);
202 }
203 return true;
204 }
205 };
206
_create()207 static void _create() {
208 std::shared_ptr<GeometryComputer> comp(new GeometryImageOp);
209 GeometryComputer::registerGeometryComputer(
210 comp, {
211 OpType_ConvInt8,
212 OpType_DepthwiseConvInt8,
213 OpType_ConvolutionDepthwise,
214 OpType_DeconvolutionDepthwise,
215 OpType_Pooling,
216 OpType_Interp,
217 OpType_Resize,
218 OpType_Int8ToFloat,
219 OpType_FloatToInt8
220 });
221 }
222
223 REGISTER_GEOMETRY(GeometryImageOp, _create);
224
225 } // namespace MNN
226