1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19 /*!
20 * \file resize-inl.h
 * \brief image resize operator; the CPU path uses OpenCV, the GPU path only supports bilinear resize
22 * \author Jake Lee
23 */
24 #ifndef MXNET_OPERATOR_IMAGE_RESIZE_INL_H_
25 #define MXNET_OPERATOR_IMAGE_RESIZE_INL_H_
26
#include <mxnet/base.h>
#include <cstdint>
#include <type_traits>
#include <vector>

#include "../mxnet_op.h"
#include "../operator_common.h"
#include "image_utils.h"
33
34 #if MXNET_USE_OPENCV
35 #include <opencv2/opencv.hpp>
36 #endif // MXNET_USE_OPENCV
37
38 namespace mxnet {
39 namespace op {
40 namespace image {
41
42 using namespace mshadow;
43
44 #if MXNET_USE_CUDA
/*!
 * \brief Declaration of the CUDA resize routine; only visible when MXNET_USE_CUDA is set.
 *        The definition is not part of this header.
 * \tparam DType element type of the tensors
 * \tparam T tensor type shared by input and output; this file instantiates it with
 *         Tensor<gpu, 3, DType> and Tensor<gpu, 4, DType>
 * \tparam Acctype instantiated with float in this file; presumably the type used for
 *         intermediate arithmetic -- confirm against the definition
 * \param s GPU stream
 * \param input source tensor
 * \param output destination tensor; no target height/width is passed, so presumably the
 *        target size is taken from the shape of `output` -- confirm against the definition
 */
template<typename DType, typename T, typename Acctype>
void ResizeImplCUDA(Stream<gpu> *s,
                    const T input,
                    const T output);
49 #endif // MXNET_USE_CUDA
50
51 struct ResizeParam : public dmlc::Parameter<ResizeParam> {
52 mxnet::Tuple<int> size;
53 bool keep_ratio;
54 int interp;
DMLC_DECLARE_PARAMETERResizeParam55 DMLC_DECLARE_PARAMETER(ResizeParam) {
56 DMLC_DECLARE_FIELD(size)
57 .set_default(mxnet::Tuple<int>())
58 .describe("Size of new image. Could be (width, height) or (size)");
59 DMLC_DECLARE_FIELD(keep_ratio)
60 .describe("Whether to resize the short edge or both edges to `size`, "
61 "if size is give as an integer.")
62 .set_default(false);
63 DMLC_DECLARE_FIELD(interp)
64 .set_default(1)
65 .describe("Interpolation method for resizing. By default uses bilinear interpolation"
66 "Options are INTER_NEAREST - a nearest-neighbor interpolation"
67 "INTER_LINEAR - a bilinear interpolation"
68 "INTER_AREA - resampling using pixel area relation"
69 "INTER_CUBIC - a bicubic interpolation over 4x4 pixel neighborhood"
70 "INTER_LANCZOS4 - a Lanczos interpolation over 8x8 pixel neighborhood"
71 "Note that the GPU version only support bilinear interpolation(1)");
72 }
73 };
74 // handle the keep ratio param
GetHeightAndWidth(int data_h,int data_w,const ResizeParam & param)75 inline SizeParam GetHeightAndWidth(int data_h,
76 int data_w,
77 const ResizeParam& param) {
78 CHECK((param.size.ndim() == 1) || (param.size.ndim() == 2))
79 << "Input size dimension must be 1 or 2, but got "
80 << param.size.ndim();
81 int resized_h;
82 int resized_w;
83 if (param.size.ndim() == 1) {
84 CHECK_GT(param.size[0], 0)
85 << "Input size should be greater than 0, but got "
86 << param.size[0];
87 if (!param.keep_ratio) {
88 resized_h = param.size[0];
89 resized_w = param.size[0];
90 } else {
91 if (data_h > data_w) {
92 resized_w = param.size[0];
93 resized_h = static_cast<int>(data_h * resized_w / data_w);
94 } else {
95 resized_h = param.size[0];
96 resized_w = static_cast<int>(data_w * resized_h / data_h);
97 }
98 }
99 } else {
100 CHECK_GT(param.size[0], 0)
101 << "Input width should be greater than 0, but got "
102 << param.size[0];
103 CHECK_GT(param.size[1], 0)
104 << "Input height should be greater than 0, but got "
105 << param.size[1];
106 resized_h = param.size[1];
107 resized_w = param.size[0];
108 }
109 return SizeParam(resized_h, resized_w);
110 }
111
ResizeShape(const nnvm::NodeAttrs & attrs,mxnet::ShapeVector * in_attrs,mxnet::ShapeVector * out_attrs)112 inline bool ResizeShape(const nnvm::NodeAttrs& attrs,
113 mxnet::ShapeVector *in_attrs,
114 mxnet::ShapeVector *out_attrs) {
115 // input attrs should only be (h, w, c) or (n, h, w, c)
116 CHECK((in_attrs->at(0).ndim() == 3U) || (in_attrs->at(0).ndim() == 4U))
117 << "Input image dimension should be 3 or 4 but got "
118 << in_attrs->at(0).ndim();
119 const auto& ishape = (*in_attrs)[0];
120 const ResizeParam& param = nnvm::get<ResizeParam>(attrs.parsed);
121 SizeParam size;
122 if (ishape.ndim() == 3) {
123 size = GetHeightAndWidth(ishape[H], ishape[W], param);
124 SHAPE_ASSIGN_CHECK(*out_attrs, 0, mxnet::TShape({size.height, size.width, ishape[C]}));
125 } else {
126 size = GetHeightAndWidth(ishape[kH], ishape[kW], param);
127 SHAPE_ASSIGN_CHECK(*out_attrs, 0,
128 mxnet::TShape({ishape[N], size.height, size.width, ishape[kC]}));
129 }
130 return true;
131 }
132
133 inline void ResizeImpl(const std::vector<TBlob> &inputs,
134 const std::vector<TBlob> &outputs,
135 const int height,
136 const int width,
137 const int interp,
138 const int input_index = 0,
139 const int output_index = 0) {
140 #if MXNET_USE_OPENCV
141 CHECK_NE(inputs[0].type_flag_, mshadow::kFloat16) << "opencv image mat doesn't support fp16";
142 CHECK((inputs[0].type_flag_ != mshadow::kInt32) || (inputs[0].type_flag_ != mshadow::kInt64))
143 << "opencv resize doesn't support int32, int64";
144 // mapping to opencv matrix element type according to channel
145 const int DTYPE[] = {CV_32F, CV_64F, -1, CV_8U, CV_32S};
146 if (inputs[0].ndim() == 3) {
147 const int cv_type = CV_MAKETYPE(DTYPE[inputs[0].type_flag_], inputs[0].shape_[C]);
148 cv::Mat buf(inputs[0].shape_[H], inputs[0].shape_[W], cv_type, inputs[0].dptr_);
149 cv::Mat dst(outputs[0].shape_[H], outputs[0].shape_[W], cv_type, outputs[0].dptr_);
150 cv::resize(buf, dst, cv::Size(width, height), 0, 0, interp);
151 CHECK(!dst.empty());
152 CHECK_EQ(static_cast<void*>(dst.ptr()), outputs[0].dptr_);
153 } else {
154 const int cv_type = CV_MAKETYPE(DTYPE[inputs[0].type_flag_], inputs[0].shape_[kC]);
155 MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
156 cv::Mat buf(inputs[0].shape_[kH], inputs[0].shape_[kW], cv_type,
157 inputs[0].dptr<DType>() + input_index);
158 cv::Mat dst(outputs[0].shape_[kH], outputs[0].shape_[kW], cv_type,
159 outputs[0].dptr<DType>() + output_index);
160 cv::resize(buf, dst, cv::Size(width, height), 0, 0, interp);
161 CHECK(!dst.empty());
162 CHECK_EQ(static_cast<void*>(dst.ptr()), outputs[0].dptr<DType>() + output_index);
163 });
164 }
165 #else
166 LOG(FATAL) << "Build with USE_OPENCV=1 for image resize operator.";
167 #endif // MXNET_USE_OPENCV
168 }
169
170 template <typename xpu>
Resize(const nnvm::NodeAttrs & attrs,const OpContext & ctx,const std::vector<TBlob> & inputs,const std::vector<OpReqType> & req,const std::vector<TBlob> & outputs)171 inline void Resize(const nnvm::NodeAttrs &attrs,
172 const OpContext &ctx,
173 const std::vector<TBlob> &inputs,
174 const std::vector<OpReqType> &req,
175 const std::vector<TBlob> &outputs) {
176 CHECK_EQ(outputs.size(), 1U);
177 const ResizeParam& param = nnvm::get<ResizeParam>(attrs.parsed);
178 SizeParam size;
179 if (std::is_same<xpu, gpu>::value) {
180 #if MXNET_USE_CUDA
181 CHECK(param.interp == 1) << "interp should be 1 for using Resize on GPU.";
182 mshadow::Stream<gpu> *s = ctx.get_stream<gpu>();
183 MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, DType, {
184 if (inputs[0].ndim() == 3) {
185 Tensor<gpu, 3, DType> input = inputs[0].get<gpu, 3, DType>(s);
186 Tensor<gpu, 3, DType> output = outputs[0].get<gpu, 3, DType>(s);
187 ResizeImplCUDA<DType, Tensor<gpu, 3, DType>, float>(s, input, output);
188 } else {
189 Tensor<gpu, 4, DType> input = inputs[0].get<gpu, 4, DType>(s);
190 Tensor<gpu, 4, DType> output = outputs[0].get<gpu, 4, DType>(s);
191 ResizeImplCUDA<DType, Tensor<gpu, 4, DType>, float>(s, input, output);
192 }
193 });
194 #endif // MXNET_USE_CUDA
195 } else if (inputs[0].ndim() == 3) {
196 size = GetHeightAndWidth(inputs[0].shape_[H], inputs[0].shape_[W], param);
197 ResizeImpl(inputs, outputs, size.height, size.width, param.interp);
198 } else {
199 size = GetHeightAndWidth(inputs[0].shape_[kH], inputs[0].shape_[kW], param);
200 const auto batch_size = inputs[0].shape_[N];
201 const auto input_step = inputs[0].shape_[kH] * inputs[0].shape_[kW] * inputs[0].shape_[kC];
202 const auto output_step = size.height * size.width * inputs[0].shape_[kC];
203 #pragma omp parallel for
204 for (auto i = 0; i < batch_size; ++i) {
205 ResizeImpl(inputs, outputs, size.height, size.width,
206 param.interp, i * input_step, i * output_step);
207 }
208 }
209 }
210
211 } // namespace image
212 } // namespace op
213 } // namespace mxnet
214
215 #endif // MXNET_OPERATOR_IMAGE_RESIZE_INL_H_
216