/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
/*!
 * \file bilinear_resize-inl.h
 * \brief bilinear resize operator
 * \author Hang Zhang
 */
24 #ifndef MXNET_OPERATOR_CONTRIB_BILINEAR_RESIZE_INL_H_
25 #define MXNET_OPERATOR_CONTRIB_BILINEAR_RESIZE_INL_H_
26
27 #include <dmlc/logging.h>
28 #include <dmlc/parameter.h>
29 #include <mxnet/operator.h>
30 #include <mxnet/ndarray.h>
31 #include <map>
32 #include <vector>
33 #include <string>
34 #include <utility>
35 /* contrib
36 #include "../ndarray/ndarray_function.h"
37 #include "./operator_common.h"
38 #include "./mxnet_op.h"
39 #include "./mshadow_op.h"
40 */
41 #include "../../ndarray/ndarray_function.h"
42 #include "../operator_common.h"
43 #include "../mxnet_op.h"
44 #include "../mshadow_op.h"
45
namespace bilinear_resize {
/*! \brief Resizing modes accepted by the `mode` parameter of BilinearResize2D. */
enum BilinearResizeOpMode {
  simple,        // output size taken from `height`/`width` (or the scale_* params)
  odd_scale,     // scale so odd input dims map to (d - 1) * scale + 1
  like,          // match the height/width of a second ("like") input
  to_even_down,  // odd dims are rounded down to the nearest even value
  to_even_up,    // odd dims are rounded up to the nearest even value
  to_odd_down,   // even dims are rounded down to the nearest odd value
  to_odd_up      // even dims are rounded up to the nearest odd value
};
}  // namespace bilinear_resize
50
51
52 namespace mxnet {
53 namespace op {
54
55 struct BilinearSampleParam : public dmlc::Parameter<BilinearSampleParam> {
56 int height;
57 int width;
58 dmlc::optional<float> scale_height;
59 dmlc::optional<float> scale_width;
60 int mode;
61 bool align_corners;
DMLC_DECLARE_PARAMETERBilinearSampleParam62 DMLC_DECLARE_PARAMETER(BilinearSampleParam) {
63 DMLC_DECLARE_FIELD(height).set_default(1).set_lower_bound(1)
64 .describe("output height (required, but ignored if scale_height is defined or mode is not "
65 "\"size\")");
66 DMLC_DECLARE_FIELD(width).set_default(1).set_lower_bound(1)
67 .describe("output width (required, but ignored if scale_width is defined or mode is not "
68 "\"size\")");
69 DMLC_DECLARE_FIELD(scale_height).set_default(dmlc::optional<float>())
70 .describe("sampling scale of the height (optional, used in modes \"scale\" and \"odd_scale\")");
71 DMLC_DECLARE_FIELD(scale_width).set_default(dmlc::optional<float>())
72 .describe("sampling scale of the width (optional, used in modes \"scale\" and \"odd_scale\")");
73 DMLC_DECLARE_FIELD(mode)
74 .add_enum("size", bilinear_resize::simple)
75 .add_enum("odd_scale", bilinear_resize::odd_scale)
76 .add_enum("like", bilinear_resize::like)
77 .add_enum("to_even_down", bilinear_resize::to_even_down)
78 .add_enum("to_even_up", bilinear_resize::to_even_up)
79 .add_enum("to_odd_down", bilinear_resize::to_odd_down)
80 .add_enum("to_odd_up", bilinear_resize::to_odd_up)
81 .set_default(bilinear_resize::simple)
82 .describe("resizing mode. \"simple\" - output height equals parameter \"height\" if "
83 "\"scale_height\" parameter is not defined or input height multiplied by "
84 "\"scale_height\" otherwise. Same for width;"
85 "\"odd_scale\" - if original height or width is odd, then result height is "
86 "calculated like result_h = (original_h - 1) * scale + 1; "
87 "for scale > 1 the result shape would be like if we did deconvolution with kernel "
88 "= (1, 1) and stride = (height_scale, width_scale); and for scale < 1 shape "
89 "would be like we did convolution with kernel = (1, 1) and "
90 "stride = (int(1 / height_scale), int( 1/ width_scale);"
91 "\"like\" - resize first input to the height and width of second input; "
92 "\"to_even_down\" - resize input to nearest lower even height and width "
93 "(if original height is odd then result height = original height - 1);"
94 "\"to_even_up\" - resize input to nearest bigger even height and width "
95 "(if original height is odd then result height = original height + 1);"
96 "\"to_odd_down\" - resize input to nearest odd height and width "
97 "(if original height is odd then result height = original height - 1);"
98 "\"to_odd_up\" - resize input to nearest odd height and width "
99 "(if original height is odd then result height = original height + 1);");
100 DMLC_DECLARE_FIELD(align_corners).set_default(true)
101 .describe("With align_corners = True, the interpolating doesn't proportionally align the"
102 "output and input pixels, and thus the output values can depend on the input size.");
103 }
104 };
105
/*!
 * \brief Compute the resampling scale between an input and output extent.
 * \param input_size   source extent (pixels)
 * \param output_size  destination extent (pixels)
 * \param align_corners corner-alignment convention (see formula below)
 * \return scale factor mapping destination indices to source indices;
 *         0 when output_size <= 1 (a single output pixel always samples index 0).
 */
template <typename DType>
static inline DType area_pixel_compute_scale(
    int64_t input_size,
    int64_t output_size,
    bool align_corners) {
  /* We view each pixel as an area, idx + 0.5 as its center index.
   * Here is an example formula in 1D case.
   * if align_corners: center of two corner pixel areas are preserved,
   *     (0.5, 0.5) -> (0.5, 0.5),
   *     (input_size - 0.5, 0.5) -> (output_size - 0.5)
   *     scale = (input_size - 0.5 - 0.5) / (output_size - 0.5 - 0.5)
   *     src_index + 0.5 - 0.5 = scale * (dst_index + 0.5 - 0.5)
   * if not align_corners: the whole range is scaled accordingly
   *     scale = input_size / output_size
   *     src_idx + 0.5 = scale * (dst_index + 0.5)
   */
  if (output_size > 1) {
    return align_corners
        ? static_cast<DType>(input_size - 1) / (output_size - 1)
        : static_cast<DType>(input_size) / output_size;
  } else {
    return DType(0);
  }
}
130
/*!
 * \brief Map a destination pixel index back to its (fractional) source index.
 * \param scale        scale from area_pixel_compute_scale
 * \param dst_index    destination pixel index
 * \param align_corners corner-alignment convention
 * \param cubic        true for cubic interpolation, which needs unbounded
 *                     (possibly negative) source indices; linear modes clamp at 0
 * \return fractional source index
 */
template <typename DType>
static inline DType area_pixel_compute_source_index(
    DType scale,
    int64_t dst_index,
    bool align_corners,
    bool cubic) {
  if (align_corners) {
    return scale * dst_index;
  } else {
    DType src_idx = scale * (dst_index + 0.5) - 0.5;
    // [Note] Follow Opencv resize logic:
    // We allow negative src_idx here and later will use
    //   dx = src_idx - floorf(src_idx)
    // to compute the "distance"(which affects weights).
    // For linear modes, weight distribution doesn't matter
    // for negative indices as they use 2 pixels to interpolate.
    // For example, [-1, 0], they both use pixel 0 value so it
    // doesn't affect if we bound the src_idx to 0 or not.
    // TODO(chinakook): Our current linear mode impls use unbound indices
    // where we should and then remove this cubic flag.
    // This matters in cubic mode, as we might need [-1, 0, 1, 2]
    // to interpolate and the weights can be affected.
    return (!cubic && src_idx < 0) ? DType(0) : src_idx;
  }
}
156
IsWriting(const OpReqType ort)157 static inline bool IsWriting(const OpReqType ort) {
158 return ort == kWriteTo || ort == kWriteInplace;
159 }
160
161 template<typename xpu, typename DType, typename AccReal>
162 void SpatialUpSamplingBilinearUpdateOutput(mshadow::Stream<cpu> *s,
163 const std::vector<TBlob> &input,
164 const std::vector<TBlob> &output,
165 bool align_corners);
166
167 template<typename xpu, typename DType, typename AccReal>
168 void SpatialUpSamplingBilinearUpdateGradInput(mshadow::Stream<cpu> *s,
169 const std::vector<TBlob> &input,
170 const std::vector<TBlob> &output,
171 bool modeLike,
172 bool align_corners);
173
174 #if MXNET_USE_CUDA
175 template<typename xpu, typename DType, typename AccReal>
176 void SpatialUpSamplingBilinearUpdateOutput(mshadow::Stream<gpu> *s,
177 const std::vector<TBlob> &input,
178 const std::vector<TBlob> &output,
179 bool align_corners);
180
181 template<typename xpu, typename DType, typename AccReal>
182 void SpatialUpSamplingBilinearUpdateGradInput(mshadow::Stream<gpu> *s,
183 const std::vector<TBlob> &input,
184 const std::vector<TBlob> &output,
185 bool modeLike,
186 bool align_corners);
187 #endif // MXNET_USE_CUDA
188
189 template <typename xpu>
BilinearSampleOpForward(const nnvm::NodeAttrs & attrs,const OpContext & ctx,const std::vector<TBlob> & inputs,const std::vector<OpReqType> & req,const std::vector<TBlob> & outputs)190 inline void BilinearSampleOpForward(const nnvm::NodeAttrs& attrs,
191 const OpContext &ctx,
192 const std::vector<TBlob> &inputs,
193 const std::vector<OpReqType> &req,
194 const std::vector<TBlob> &outputs) {
195 const BilinearSampleParam& param = nnvm::get<BilinearSampleParam>(attrs.parsed);
196 size_t expected = param.mode == bilinear_resize::like ? 2 : 1;
197 CHECK_EQ(inputs.size(), expected);
198 CHECK_EQ(outputs.size(), 1U);
199 CHECK_EQ(inputs[0].CheckContiguous(), true);
200 if (expected == 2) {
201 CHECK_EQ(inputs[1].CheckContiguous(), true);
202 }
203 CHECK_EQ(outputs[0].CheckContiguous(), true);
204
205 bool align_corners = param.align_corners;
206 mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
207 MSHADOW_REAL_TYPE_SWITCH_EX(inputs[0].type_flag_, DType, AccReal, {
208 SpatialUpSamplingBilinearUpdateOutput<xpu, DType, AccReal>(s, inputs, outputs, align_corners);
209 });
210 }
211
212
213 template <typename xpu>
BilinearSampleOpBackward(const nnvm::NodeAttrs & attrs,const OpContext & ctx,const std::vector<TBlob> & inputs,const std::vector<OpReqType> & req,const std::vector<TBlob> & outputs)214 inline void BilinearSampleOpBackward(const nnvm::NodeAttrs& attrs,
215 const OpContext &ctx,
216 const std::vector<TBlob> &inputs,
217 const std::vector<OpReqType> &req,
218 const std::vector<TBlob> &outputs) {
219 const BilinearSampleParam& param = nnvm::get<BilinearSampleParam>(attrs.parsed);
220 CHECK_EQ(inputs.size(), 1U);
221 bool modeLike = param.mode == bilinear_resize::like;
222 bool align_corners = param.align_corners;
223 size_t expected = modeLike ? 2 : 1;
224 CHECK_EQ(outputs.size(), expected);
225 mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
226 if (IsWriting(req[0])) {
227 // zero grad before backwarding
228 MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, DType, {
229 Fill<false>(s, outputs[0], kWriteTo, 0);
230 })
231 }
232 MSHADOW_REAL_TYPE_SWITCH_EX(inputs[0].type_flag_, DType, AccReal, {
233 SpatialUpSamplingBilinearUpdateGradInput<xpu, DType, AccReal>(s, inputs, outputs
234 , modeLike, align_corners);
235 });
236 }
237
238
BilinearSampleOpInferShape(const nnvm::NodeAttrs & attrs,mxnet::ShapeVector * in_shape,mxnet::ShapeVector * out_shape)239 static bool BilinearSampleOpInferShape(const nnvm::NodeAttrs& attrs,
240 mxnet::ShapeVector *in_shape,
241 mxnet::ShapeVector *out_shape) {
242 using namespace mshadow;
243 CHECK_EQ(out_shape->size(), 1U) << "Output:[data]";
244 const BilinearSampleParam& param = nnvm::get<BilinearSampleParam>(attrs.parsed);
245 size_t expected = param.mode == bilinear_resize::like ? 2 : 1;
246 CHECK_EQ(in_shape->size(), expected);
247 mxnet::TShape dshape(in_shape->at(0));
248 if (mxnet::op::shape_is_none(dshape)) return false;
249 int16_t new_height = -1;
250 int16_t new_width = -1;
251 switch (param.mode) {
252 case bilinear_resize::simple:
253 {
254 if (param.scale_height.has_value()) {
255 new_height = static_cast<int>(param.scale_height.value() * in_shape->at(0)[2]);
256 } else {
257 new_height = param.height;
258 }
259 if (param.scale_height.has_value()) {
260 new_width = static_cast<int>(param.scale_width.value() * in_shape->at(0)[3]);
261 } else {
262 new_width = param.width;
263 }
264 break;
265 }
266 case bilinear_resize::odd_scale:
267 {
268 new_height = ((dshape[2] % 2) == 0) ? (int16_t) (dshape[2] * param.scale_height.value()) :
269 (int16_t) ((dshape[2] - 1) * param.scale_height.value()) + 1;
270 new_width = ((dshape[3] % 2) == 0) ? (int16_t) (dshape[3] * param.scale_width.value()) :
271 (int16_t) ((dshape[3] - 1) * param.scale_width.value()) + 1;
272 break;
273 }
274 case bilinear_resize::like:
275 {
276 TShape like_shape(in_shape->at(1));
277 if (dshape.ndim() == 0) return false;
278 new_height = like_shape[2];
279 new_width = like_shape[3];
280 break;
281 }
282 case bilinear_resize::to_even_down:
283 {
284 new_height = ((dshape[2] % 2) == 0) ? dshape[2] : dshape[2] - 1;
285 new_width = ((dshape[3] % 2) == 0) ? dshape[3] : dshape[3] - 1;
286 break;
287 }
288 case bilinear_resize::to_even_up:
289 {
290 new_height = ((dshape[2] % 2) == 0) ? dshape[2] : dshape[2] + 1;
291 new_width = ((dshape[3] % 2) == 0) ? dshape[3] : dshape[3] + 1;
292 break;
293 }
294 case bilinear_resize::to_odd_down:
295 {
296 new_height = ((dshape[2] % 2) == 1) ? dshape[2] : dshape[2] - 1;
297 new_width = ((dshape[3] % 2) == 1) ? dshape[3] : dshape[3] - 1;
298 break;
299 }
300 case bilinear_resize::to_odd_up:
301 {
302 new_height = ((dshape[2] % 2) == 1) ? dshape[2] : dshape[2] + 1;
303 new_width = ((dshape[3] % 2) == 1) ? dshape[3] : dshape[3] + 1;
304 break;
305 }
306 default:
307 {
308 LOG(FATAL) << "Invalid mode " << param.mode;
309 }
310 }
311
312 dshape[2] = new_height;
313 dshape[3] = new_width;
314
315 out_shape->clear();
316 out_shape->push_back(dshape);
317 return true;
318 }
319
320
BilinearSampleOpNumInputs(const NodeAttrs & attrs)321 inline uint16_t BilinearSampleOpNumInputs(const NodeAttrs& attrs) {
322 auto& param = nnvm::get<BilinearSampleParam>(attrs.parsed);
323 if (param.mode == bilinear_resize::like) {
324 return 2;
325 } else {
326 return 1;
327 }
328 }
329
BilinearSampleOpNumBackwardOutputs(const NodeAttrs & attrs)330 inline uint16_t BilinearSampleOpNumBackwardOutputs(const NodeAttrs& attrs) {
331 auto& param = nnvm::get<BilinearSampleParam>(attrs.parsed);
332 if (param.mode == bilinear_resize::like) {
333 return 2;
334 } else {
335 return 1;
336 }
337 }
338
BilinearSampleOpInputNames(const NodeAttrs & attrs)339 inline std::vector<std::string> BilinearSampleOpInputNames(const NodeAttrs& attrs) {
340 auto& param = nnvm::get<BilinearSampleParam>(attrs.parsed);
341 if (param.mode == bilinear_resize::like) {
342 return std::vector<std::string>{"data", "like"};
343 } else {
344 return std::vector<std::string>{"data"};
345 }
346 }
347
348 } // namespace op
349 } // namespace mxnet
350
351 #endif // MXNET_OPERATOR_CONTRIB_BILINEAR_RESIZE_INL_H_
352