1 //
2 //  GeometryImageOp.cpp
3 //  MNN
4 //
5 //  Created by MNN on 2020/05/07.
6 //  Copyright © 2018, Alibaba Group Holding Limited
7 //
8 
9 #include "ConvertUtils.hpp"
10 #include "geometry/GeometryComputer.hpp"
11 #include "geometry/GeometryComputerUtils.hpp"
12 #include "shape/SizeComputer.hpp"
13 namespace MNN {
14 
15 
16 /**
17  if coordinate_transformation_mode is "half_pixel",
18  x_original = (x_resized + 0.5) / scale - 0.5,
19 
20  if coordinate_transformation_mode is "pytorch_half_pixel",
21  x_original = length_resized > 1 ? (x_resized + 0.5) / scale - 0.5 : 0,
22 
23  if coordinate_transformation_mode is "align_corners",
24  x_original = x_resized * (length_original - 1) / (length_resized - 1),
25 
26  if coordinate_transformation_mode is "asymmetric",
27  x_original = x_resized / scale,
28 
29  if coordinate_transformation_mode is "tf_half_pixel_for_nn",
30  x_original = (x_resized + 0.5) / scale,
31 
32  if coordinate_transformation_mode is "tf_crop_and_resize",
33  x_original = length_resized > 1 ? start_x * (length_original - 1) + x_resized * (end_x - start_x) * (length_original - 1) / (length_resized - 1) : 0.5 * (start_x + end_x) * (length_original - 1).
34  */
/**
 Scale / offset pair describing how an output (resized) coordinate maps back to
 the input image, one pair per spatial axis.

 Every member has a default so that a freshly constructed InterpInfo is always
 fully defined, even when the code that fills it skips a field (for example on
 an unsupported coordinate transformation mode). The defaults, scale 1 and
 offset 0, leave coordinates unchanged.
 */
struct InterpInfo {
    // Callers in this file store (input height / output height) style ratios here.
    float heightScale = 1.0f;
    // Callers in this file store (input width / output width) style ratios here.
    float widthScale = 1.0f;
    // Constant term added along the width axis; 0 means no shift.
    float widthOffset = 0.0f;
    // Constant term added along the height axis; 0 means no shift.
    float heightOffset = 0.0f;
};
_ConverterInterp(const Interp * resize,InterpInfo * dstInfo,int inW,int inH,int outW,int outH)41 static void _ConverterInterp(const Interp* resize, InterpInfo* dstInfo, int inW, int inH, int outW, int outH) {
42     switch (resize->ctm()) {
43         case CoordinateTransformationMode_NotSet:
44         {
45             // For compability, old model's nearest don't support halfpixels
46             if (resize->halfPixelCenters() && resize->resizeType() != 1) {
47                 dstInfo->heightScale = (float)(inH) / (float)(outH);
48                 dstInfo->widthScale  = (float)(inW) / (float)(outW);
49                 dstInfo->widthOffset = 0.5f * dstInfo->widthScale - 0.5f;
50                 dstInfo->heightOffset = 0.5f * dstInfo->heightScale - 0.5f;
51             } else if (resize->alignCorners()) {
52                 if (outH == 1) {
53                     dstInfo->heightScale = 0.0f;
54                 } else {
55                     dstInfo->heightScale = (float)(inH - 1) / (float)(outH - 1);
56                 }
57                 if (outW == 1) {
58                     dstInfo->widthScale = 0.0f;
59                 } else {
60                     dstInfo->widthScale  = (float)(inW - 1) / (float)(outW - 1);
61                 }
62             } else {
63                 dstInfo->heightScale = (float)(inH) / (float)(outH);
64                 dstInfo->widthScale  = (float)(inW) / (float)(outW);
65             }
66             break;
67         }
68         case CoordinateTransformationMode_AlignCorners:
69         {
70             if (outH == 1) {
71                 dstInfo->heightScale = 0.0f;
72             } else {
73                 dstInfo->heightScale = (float)(inH - 1) / (float)(outH - 1);
74             }
75             if (outW == 1) {
76                 dstInfo->widthScale = 0.0f;
77             } else {
78                 dstInfo->widthScale  = (float)(inW - 1) / (float)(outW - 1);
79             }
80             break;
81         }
82         case CoordinateTransformationMode_HalfPixels:
83         {
84             dstInfo->heightScale = (float)(inH) / (float)(outH);
85             dstInfo->widthScale  = (float)(inW) / (float)(outW);
86             dstInfo->widthOffset = 0.5f * dstInfo->widthScale - 0.5f;
87             dstInfo->heightOffset = 0.5f * dstInfo->heightScale - 0.5f;
88             break;
89         }
90         case CoordinateTransformationMode_PytorchHalfPixels:
91         {
92             if (outH > 1) {
93                 dstInfo->heightScale = (float)inH / (float)outH;
94                 dstInfo->heightOffset = 0.5f * dstInfo->heightScale - 0.5f;
95             } else {
96                 dstInfo->heightScale = 0.0f;
97             }
98             if (outW > 1) {
99                 dstInfo->widthScale = (float)inW / (float)outW;
100                 dstInfo->widthOffset = 0.5f * dstInfo->widthScale - 0.5f;
101             } else {
102                 dstInfo->widthScale = 0.0f;
103             }
104             break;
105         }
106         case CoordinateTransformationMode_Asymmetric:
107         {
108             dstInfo->heightScale = (float)(inH) / (float)(outH);
109             dstInfo->widthScale  = (float)(inW) / (float)(outW);
110             break;
111         }
112         case CoordinateTransformationMode_TensorflowHalfPixels:
113         {
114             dstInfo->heightScale = (float)(inH) / (float)(outH);
115             dstInfo->widthScale  = (float)(inW) / (float)(outW);
116             dstInfo->widthOffset = 0.5f * dstInfo->widthScale;
117             dstInfo->heightOffset = 0.5f * dstInfo->heightScale;
118             break;
119         }
120         case CoordinateTransformationMode_TensorflowCropAndResize:
121         {
122             //FIXME: Not support now
123             MNN_ERROR("Don't support CoordinateTransformationMode_TensorflowCropAndResize currently\n");
124             break;
125         }
126         default:
127             break;
128     }
129 }
makeInterp(flatbuffers::FlatBufferBuilder & builder,const InterpInfo * info,int resizeType,const Op * op)130 static flatbuffers::Offset<Op> makeInterp(flatbuffers::FlatBufferBuilder& builder, const InterpInfo* info, int resizeType, const Op* op) {
131     flatbuffers::Offset<flatbuffers::String> temp;
132     if (nullptr != op->name()) {
133         temp = builder.CreateString(op->name()->str());
134     }
135     InterpBuilder intpB(builder);
136     intpB.add_resizeType(resizeType);
137     intpB.add_widthScale(info->widthScale);
138     intpB.add_heightScale(info->heightScale);
139     intpB.add_heightOffset(info->heightOffset);
140     intpB.add_widthOffset(info->widthOffset);
141     auto offsetInterp = intpB.Finish().Union();
142     OpBuilder opB(builder);
143     opB.add_type(OpType_Interp);
144     opB.add_main(offsetInterp);
145     opB.add_main_type(OpParameter_Interp);
146     if (nullptr != op->name()) {
147         opB.add_name(temp);
148     }
149     return opB.Finish();
150 }
151 
152 class GeometryImageOp : public GeometryComputer {
153 public:
onCompute(const Op * op,const std::vector<Tensor * > & inputs,const std::vector<Tensor * > & outputs,Context & context,CommandBuffer & res) const154     virtual bool onCompute(const Op* op, const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs,
155                            Context& context, CommandBuffer& res) const override {
156         auto newOutputs   = outputs;
157         auto newInputs    = inputs;
158         auto originOutput = outputs[0];
159         auto output       = originOutput;
160         auto inputDes     = TensorUtils::getDescribe(newInputs[0]);
161         auto format       = inputDes->dimensionFormat;
162         if (MNN_DATA_FORMAT_NC4HW4 != format) {
163             std::shared_ptr<Tensor> newInput(new Tensor(newInputs[0], Tensor::CAFFE_C4, false));
164             ConvertUtils::compute(newInputs[0], newInput.get(), res);
165             newInputs[0] = newInput.get();
166             res.extras.emplace_back(std::move(newInput));
167             std::shared_ptr<Tensor> newOutput(new Tensor(originOutput, Tensor::CAFFE_C4, false));
168             output        = newOutput.get();
169             newOutputs[0] = output;
170             res.extras.emplace_back(newOutput);
171         }
172         if (OpType_Resize == op->type()) {
173             // Turn resize to interp
174             InterpInfo info;
175             info.widthScale = (float)inputs[0]->width() / (float)outputs[0]->width();
176             info.heightScale = (float)inputs[0]->height() / (float)outputs[0]->height();
177             flatbuffers::FlatBufferBuilder builder;
178             builder.Finish(makeInterp(builder, &info, 2, op));
179             res.command.emplace_back(GeometryComputerUtils::makeCommand(builder, {newInputs[0]}, newOutputs));
180         }
181         else if (OpType_Interp == op->type()) {
182             // Compute cord transform for interp
183             auto resize                           = op->main_as_Interp();
184             auto inW = inputs[0]->width();
185             auto inH = inputs[0]->height();
186             auto outW = outputs[0]->width();
187             auto outH = outputs[0]->height();
188             InterpInfo info;
189             _ConverterInterp(resize, &info, inW, inH, outW, outH);
190             flatbuffers::FlatBufferBuilder builder;
191             builder.Finish(makeInterp(builder, &info, resize->resizeType(), op));
192             res.command.emplace_back(GeometryComputerUtils::makeCommand(builder, {newInputs[0]}, newOutputs));
193         } else {
194             Command cmd;
195             cmd.op      = op;
196             cmd.inputs  = std::move(newInputs);
197             cmd.outputs = std::move(newOutputs);
198             res.command.emplace_back(std::move(cmd));
199         }
200         if (originOutput != output) {
201             ConvertUtils::compute(output, originOutput, res);
202         }
203         return true;
204     }
205 };
206 
_create()207 static void _create() {
208     std::shared_ptr<GeometryComputer> comp(new GeometryImageOp);
209     GeometryComputer::registerGeometryComputer(
210         comp, {
211         OpType_ConvInt8,
212         OpType_DepthwiseConvInt8,
213         OpType_ConvolutionDepthwise,
214         OpType_DeconvolutionDepthwise,
215         OpType_Pooling,
216         OpType_Interp,
217         OpType_Resize,
218         OpType_Int8ToFloat,
219         OpType_FloatToInt8
220     });
221 }
222 
223 REGISTER_GEOMETRY(GeometryImageOp, _create);
224 
225 } // namespace MNN
226