1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
/*!
 * \file spatial_transformer.cc
 * \brief CPU bilinear sampling kernels and operator registration for SpatialTransformer
 * \author Wei Wu
 */
25
26 #include "./spatial_transformer-inl.h"
27
28 namespace mshadow {
29 template<typename DType>
between(const DType value,const DType lowerBound,const DType upperBound)30 static MSHADOW_CINLINE bool between(const DType value,
31 const DType lowerBound,
32 const DType upperBound) {
33 return value >= lowerBound && value <= upperBound;
34 }
35
36 template<typename DType>
BilinearSamplingForward(const Tensor<cpu,4,DType> & output,const Tensor<cpu,4,DType> & input,const Tensor<cpu,3,DType> grid_src)37 inline void BilinearSamplingForward(const Tensor<cpu, 4, DType> &output,
38 const Tensor<cpu, 4, DType> &input,
39 const Tensor<cpu, 3, DType> grid_src) {
40 DType *out = output.dptr_;
41 const DType *data = input.dptr_;
42 const DType *grid = grid_src.dptr_;
43 const index_t o_n = output.size(0), o_c = output.size(1),
44 o_h = output.size(2), o_w = output.size(3);
45 const index_t i_c = input.size(1), i_h = input.size(2), i_w = input.size(3);
46 for (index_t n = 0; n < static_cast<index_t>(o_n); ++n) {
47 for (index_t c = 0; c < static_cast<index_t>(o_c); ++c) {
48 for (index_t h = 0; h < static_cast<index_t>(o_h); ++h) {
49 for (index_t w = 0; w < static_cast<index_t>(o_w); ++w) {
50 const index_t out_index = n * o_c * o_h * o_w + c * o_h * o_w + h * o_w + w;
51 const index_t grid_index = n * o_h * o_w * 2 + h * o_w + w;
52 const DType y_real = (*(grid + grid_index + o_h * o_w) + 1) * (i_h - 1) / 2;
53 const DType x_real = (*(grid + grid_index) + 1) * (i_w - 1) / 2;
54 const auto top_left_y = static_cast<index_t>(std::floor(y_real));
55 const auto top_left_x = static_cast<index_t>(std::floor(x_real));
56 const DType top_left_y_w = 1.0 - (y_real - top_left_y);
57 const DType top_left_x_w = 1.0 - (x_real - top_left_x);
58 const index_t data_index = n * i_c * i_h * i_w + c * i_h * i_w +
59 top_left_y * i_w + top_left_x;
60 DType top_left_v = 0;
61 DType top_right_v = 0;
62 DType bottom_left_v = 0;
63 DType bottom_right_v = 0;
64 index_t lower_bound = 0;
65 if (between(top_left_x, lower_bound, i_w-1) &&
66 between(top_left_y, lower_bound, i_h-1))
67 top_left_v = *(data + data_index);
68 if (between(top_left_x + 1, lower_bound, i_w-1) &&
69 between(top_left_y, lower_bound, i_h-1))
70 top_right_v = *(data + data_index + 1);
71 if (between(top_left_x, lower_bound, i_w-1) &&
72 between(top_left_y + 1, lower_bound, i_h-1))
73 bottom_left_v = *(data + data_index + i_w);
74 if (between(top_left_x+1, lower_bound, i_w-1) &&
75 between(top_left_y + 1, lower_bound, i_h-1))
76 bottom_right_v = *(data + data_index + i_w + 1);
77 *(out+out_index) = top_left_v * top_left_y_w * top_left_x_w +
78 top_right_v * top_left_y_w * (1.0 - top_left_x_w) +
79 bottom_left_v * (1.0 - top_left_y_w) * top_left_x_w +
80 bottom_right_v * (1.0 - top_left_y_w) * (1.0 - top_left_x_w);
81 }
82 }
83 }
84 }
85 }
86
87 template<typename DType>
BilinearSamplingBackward(const Tensor<cpu,4,DType> & input_grad,const Tensor<cpu,3,DType> & grid_src_data,const Tensor<cpu,4,DType> & output_grad,const Tensor<cpu,4,DType> & input_data)88 inline void BilinearSamplingBackward(const Tensor<cpu, 4, DType> &input_grad,
89 const Tensor<cpu, 3, DType> &grid_src_data,
90 const Tensor<cpu, 4, DType> &output_grad,
91 const Tensor<cpu, 4, DType> &input_data) {
92 DType *g_input = input_grad.dptr_;
93 DType *grid_src = grid_src_data.dptr_;
94 const DType *grad = output_grad.dptr_;
95 const DType *data = input_data.dptr_;
96 const index_t o_n = output_grad.size(0), o_c = output_grad.size(1),
97 o_h = output_grad.size(2), o_w = output_grad.size(3);
98 const index_t i_c = input_data.size(1), i_h = input_data.size(2), i_w = input_data.size(3);
99 for (index_t n = 0; n < static_cast<index_t>(o_n); ++n) {
100 for (index_t h = 0; h < static_cast<index_t>(o_h); ++h) {
101 for (index_t w = 0; w < static_cast<index_t>(o_w); ++w) {
102 DType top_left_y_gw = 0.0;
103 DType top_left_x_gw = 0.0;
104 const index_t grid_src_index = n * o_h * o_w * 2 + h * o_w + w;
105 const DType y_real = (*(grid_src + grid_src_index + o_h * o_w) + 1) * (i_h - 1) / 2;
106 const DType x_real = (*(grid_src + grid_src_index) + 1) * (i_w - 1) / 2;
107 const auto top_left_y = static_cast<index_t>(std::floor(y_real));
108 const auto top_left_x = static_cast<index_t>(std::floor(x_real));
109 const DType top_left_y_w = 1.0 - (y_real - top_left_y);
110 const DType top_left_x_w = 1.0 - (x_real - top_left_x);
111 for (index_t c = 0; c < static_cast<index_t>(o_c); ++c) {
112 index_t grad_index = n * o_c * o_h * o_w + c * o_h * o_w + h * o_w + w;
113 const index_t data_index = n * i_c * i_h * i_w + c * i_h * i_w +
114 top_left_y * i_w + top_left_x;
115 // calc 4 vertex value in input data
116 DType top_left_v = 0;
117 DType top_right_v = 0;
118 DType bottom_left_v = 0;
119 DType bottom_right_v = 0;
120 index_t lower_bound = 0;
121 if (between(top_left_x, lower_bound, i_w-1) &&
122 between(top_left_y, lower_bound, i_h-1)) {
123 *(g_input + data_index) += *(grad + grad_index) * top_left_y_w * top_left_x_w;
124 top_left_v = *(data + data_index);
125 }
126 if (between(top_left_x+1, lower_bound, i_w-1) &&
127 between(top_left_y, lower_bound, i_h-1)) {
128 *(g_input + data_index + 1) += *(grad + grad_index) * top_left_y_w
129 * (1.0 - top_left_x_w);
130 top_right_v = *(data + data_index + 1);
131 }
132 if (between(top_left_x, lower_bound, i_w-1) &&
133 between(top_left_y+1, lower_bound, i_h-1)) {
134 *(g_input + data_index+ i_w) += *(grad + grad_index) * (1.0 - top_left_y_w)
135 * top_left_x_w;
136 bottom_left_v = *(data + data_index + i_w);
137 }
138 if (between(top_left_x+1, lower_bound, i_w-1) &&
139 between(top_left_y+1, lower_bound, i_h-1)) {
140 *(g_input + data_index+ i_w + 1) += *(grad + grad_index) * (1.0 - top_left_y_w)
141 * (1.0 - top_left_x_w);
142 bottom_right_v = *(data + data_index + i_w + 1);
143 }
144 // calc weight grad of top_left_w, then multiple -1 is the grad of grid_src
145 top_left_y_gw -= *(grad + grad_index) * (top_right_v - bottom_right_v +
146 (top_left_v - top_right_v - bottom_left_v + bottom_right_v)
147 * top_left_x_w);
148 top_left_x_gw -= *(grad + grad_index) * (bottom_left_v - bottom_right_v +
149 (top_left_v - top_right_v - bottom_left_v + bottom_right_v)
150 * top_left_y_w);
151 }
152 // calc grid_src grad
153 *(grid_src + grid_src_index + o_h * o_w) = top_left_y_gw * (i_h - 1) / 2;
154 *(grid_src + grid_src_index) = top_left_x_gw * (i_w - 1) / 2;
155 }
156 }
157 }
158 }
159
160 } // namespace mshadow
161
162 namespace mxnet {
163 namespace op {
164 template<>
CreateOp(SpatialTransformerParam param,int dtype)165 Operator* CreateOp<cpu>(SpatialTransformerParam param, int dtype) {
166 Operator *op = nullptr;
167 MSHADOW_REAL_TYPE_SWITCH(dtype, DType, {
168 op = new SpatialTransformerOp<cpu, DType>(param);
169 })
170 return op;
171 }
172
CreateOperatorEx(Context ctx,mxnet::ShapeVector * in_shape,std::vector<int> * in_type) const173 Operator *SpatialTransformerProp::CreateOperatorEx(Context ctx, mxnet::ShapeVector *in_shape,
174 std::vector<int> *in_type) const {
175 DO_BIND_DISPATCH(CreateOp, param_, (*in_type)[0]);
176 }
177
178 DMLC_REGISTER_PARAMETER(SpatialTransformerParam);
179
180 MXNET_REGISTER_OP_PROPERTY(SpatialTransformer, SpatialTransformerProp)
181 .add_argument("data", "NDArray-or-Symbol",
182 "Input data to the SpatialTransformerOp.")
183 .add_argument("loc", "NDArray-or-Symbol",
184 "localisation net, the output dim should be 6 when transform_type "
185 "is affine. You shold initialize the weight and bias with identity tranform.")
186 .add_arguments(SpatialTransformerParam::__FIELDS__())
187 .describe("Applies a spatial transformer to input feature map.");
188
189 } // namespace op
190 } // namespace mxnet
191