1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements.  See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership.  The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License.  You may obtain a copy of the License at
9  *
10  *   http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied.  See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
20 /*!
21  * \file spatial_transformer.cc
 * \brief CPU bilinear sampling kernels and operator registration for SpatialTransformer
23  * \author Wei Wu
24 */
25 
26 #include "./spatial_transformer-inl.h"
27 
28 namespace mshadow {
29 template<typename DType>
between(const DType value,const DType lowerBound,const DType upperBound)30 static MSHADOW_CINLINE bool between(const DType value,
31                                     const DType lowerBound,
32                                     const DType upperBound) {
33   return value >= lowerBound && value <= upperBound;
34 }
35 
36 template<typename DType>
BilinearSamplingForward(const Tensor<cpu,4,DType> & output,const Tensor<cpu,4,DType> & input,const Tensor<cpu,3,DType> grid_src)37 inline void BilinearSamplingForward(const Tensor<cpu, 4, DType> &output,
38                                     const Tensor<cpu, 4, DType> &input,
39                                     const Tensor<cpu, 3, DType> grid_src) {
40   DType *out = output.dptr_;
41   const DType *data = input.dptr_;
42   const DType *grid = grid_src.dptr_;
43   const index_t o_n = output.size(0), o_c = output.size(1),
44     o_h = output.size(2), o_w = output.size(3);
45   const index_t i_c = input.size(1), i_h = input.size(2), i_w = input.size(3);
46   for (index_t n = 0; n < static_cast<index_t>(o_n); ++n) {
47     for (index_t c = 0; c < static_cast<index_t>(o_c); ++c) {
48       for (index_t h = 0; h < static_cast<index_t>(o_h); ++h) {
49         for (index_t w = 0; w < static_cast<index_t>(o_w); ++w) {
50           const index_t out_index = n * o_c * o_h * o_w + c * o_h * o_w + h * o_w + w;
51           const index_t grid_index = n * o_h * o_w * 2 + h * o_w + w;
52           const DType y_real = (*(grid + grid_index + o_h * o_w) + 1) * (i_h - 1) / 2;
53           const DType x_real = (*(grid + grid_index) + 1) * (i_w - 1) / 2;
54           const auto top_left_y = static_cast<index_t>(std::floor(y_real));
55           const auto top_left_x = static_cast<index_t>(std::floor(x_real));
56           const DType top_left_y_w = 1.0 - (y_real - top_left_y);
57           const DType top_left_x_w = 1.0 - (x_real - top_left_x);
58           const index_t data_index = n * i_c * i_h * i_w + c * i_h * i_w +
59                                  top_left_y * i_w + top_left_x;
60           DType top_left_v = 0;
61           DType top_right_v = 0;
62           DType bottom_left_v = 0;
63           DType bottom_right_v = 0;
64           index_t lower_bound = 0;
65           if (between(top_left_x, lower_bound, i_w-1) &&
66               between(top_left_y, lower_bound, i_h-1))
67             top_left_v = *(data + data_index);
68           if (between(top_left_x + 1, lower_bound, i_w-1) &&
69               between(top_left_y, lower_bound, i_h-1))
70             top_right_v = *(data + data_index + 1);
71           if (between(top_left_x, lower_bound, i_w-1) &&
72               between(top_left_y + 1, lower_bound, i_h-1))
73             bottom_left_v = *(data + data_index + i_w);
74           if (between(top_left_x+1, lower_bound, i_w-1) &&
75               between(top_left_y + 1, lower_bound, i_h-1))
76             bottom_right_v = *(data + data_index + i_w + 1);
77           *(out+out_index) = top_left_v * top_left_y_w * top_left_x_w +
78                              top_right_v * top_left_y_w * (1.0 - top_left_x_w) +
79                              bottom_left_v * (1.0 - top_left_y_w) * top_left_x_w +
80                              bottom_right_v * (1.0 - top_left_y_w) * (1.0 - top_left_x_w);
81         }
82       }
83     }
84   }
85 }
86 
87 template<typename DType>
BilinearSamplingBackward(const Tensor<cpu,4,DType> & input_grad,const Tensor<cpu,3,DType> & grid_src_data,const Tensor<cpu,4,DType> & output_grad,const Tensor<cpu,4,DType> & input_data)88 inline void BilinearSamplingBackward(const Tensor<cpu, 4, DType> &input_grad,
89                                      const Tensor<cpu, 3, DType> &grid_src_data,
90                                      const Tensor<cpu, 4, DType> &output_grad,
91                                      const Tensor<cpu, 4, DType> &input_data) {
92   DType *g_input = input_grad.dptr_;
93   DType *grid_src = grid_src_data.dptr_;
94   const DType *grad = output_grad.dptr_;
95   const DType *data = input_data.dptr_;
96   const index_t o_n = output_grad.size(0), o_c = output_grad.size(1),
97     o_h = output_grad.size(2), o_w = output_grad.size(3);
98   const index_t i_c = input_data.size(1), i_h = input_data.size(2), i_w = input_data.size(3);
99   for (index_t n = 0; n < static_cast<index_t>(o_n); ++n) {
100      for (index_t h = 0; h < static_cast<index_t>(o_h); ++h) {
101         for (index_t w = 0; w < static_cast<index_t>(o_w); ++w) {
102           DType top_left_y_gw = 0.0;
103           DType top_left_x_gw = 0.0;
104           const index_t grid_src_index = n * o_h * o_w * 2 + h * o_w + w;
105           const DType y_real = (*(grid_src + grid_src_index + o_h * o_w) + 1) * (i_h - 1) / 2;
106           const DType x_real = (*(grid_src + grid_src_index) + 1) * (i_w - 1) / 2;
107           const auto top_left_y = static_cast<index_t>(std::floor(y_real));
108           const auto top_left_x = static_cast<index_t>(std::floor(x_real));
109           const DType top_left_y_w = 1.0 - (y_real - top_left_y);
110           const DType top_left_x_w = 1.0 - (x_real - top_left_x);
111           for (index_t c = 0; c < static_cast<index_t>(o_c); ++c) {
112             index_t grad_index = n * o_c * o_h * o_w + c * o_h * o_w + h * o_w + w;
113             const index_t data_index = n * i_c * i_h * i_w + c * i_h * i_w +
114                                    top_left_y * i_w + top_left_x;
115             // calc 4 vertex value in input data
116             DType top_left_v = 0;
117             DType top_right_v = 0;
118             DType bottom_left_v = 0;
119             DType bottom_right_v = 0;
120             index_t lower_bound = 0;
121             if (between(top_left_x, lower_bound, i_w-1) &&
122                 between(top_left_y, lower_bound, i_h-1)) {
123               *(g_input + data_index) += *(grad + grad_index) * top_left_y_w * top_left_x_w;
124               top_left_v = *(data + data_index);
125             }
126             if (between(top_left_x+1, lower_bound, i_w-1) &&
127                 between(top_left_y, lower_bound, i_h-1)) {
128               *(g_input + data_index + 1) += *(grad + grad_index) * top_left_y_w
129                                              * (1.0 - top_left_x_w);
130               top_right_v = *(data + data_index + 1);
131             }
132             if (between(top_left_x, lower_bound, i_w-1) &&
133                 between(top_left_y+1, lower_bound, i_h-1)) {
134               *(g_input + data_index+ i_w) += *(grad + grad_index) * (1.0 - top_left_y_w)
135                                               * top_left_x_w;
136               bottom_left_v = *(data + data_index + i_w);
137             }
138             if (between(top_left_x+1, lower_bound, i_w-1) &&
139                 between(top_left_y+1, lower_bound, i_h-1)) {
140               *(g_input + data_index+ i_w + 1) += *(grad + grad_index) * (1.0 - top_left_y_w)
141                                                   * (1.0 - top_left_x_w);
142               bottom_right_v = *(data + data_index + i_w + 1);
143             }
144             // calc weight grad of top_left_w, then multiple -1 is the grad of grid_src
145             top_left_y_gw -= *(grad + grad_index) * (top_right_v - bottom_right_v +
146                              (top_left_v - top_right_v - bottom_left_v + bottom_right_v)
147                              * top_left_x_w);
148             top_left_x_gw -= *(grad + grad_index) * (bottom_left_v - bottom_right_v +
149                              (top_left_v - top_right_v - bottom_left_v + bottom_right_v)
150                              * top_left_y_w);
151           }
152           // calc grid_src grad
153           *(grid_src + grid_src_index + o_h * o_w) = top_left_y_gw * (i_h - 1) / 2;
154           *(grid_src + grid_src_index) = top_left_x_gw * (i_w - 1) / 2;
155         }
156       }
157     }
158   }
159 
160 }  // namespace mshadow
161 
162 namespace mxnet {
163 namespace op {
164 template<>
CreateOp(SpatialTransformerParam param,int dtype)165 Operator* CreateOp<cpu>(SpatialTransformerParam param, int dtype) {
166   Operator *op = nullptr;
167   MSHADOW_REAL_TYPE_SWITCH(dtype, DType, {
168     op = new SpatialTransformerOp<cpu, DType>(param);
169   })
170   return op;
171 }
172 
CreateOperatorEx(Context ctx,mxnet::ShapeVector * in_shape,std::vector<int> * in_type) const173 Operator *SpatialTransformerProp::CreateOperatorEx(Context ctx, mxnet::ShapeVector *in_shape,
174                                      std::vector<int> *in_type) const {
175   DO_BIND_DISPATCH(CreateOp, param_, (*in_type)[0]);
176 }
177 
178 DMLC_REGISTER_PARAMETER(SpatialTransformerParam);
179 
180 MXNET_REGISTER_OP_PROPERTY(SpatialTransformer, SpatialTransformerProp)
181 .add_argument("data", "NDArray-or-Symbol",
182               "Input data to the SpatialTransformerOp.")
183 .add_argument("loc", "NDArray-or-Symbol",
184               "localisation net, the output dim should be 6 when transform_type "
185               "is affine. You shold initialize the weight and bias with identity tranform.")
186 .add_arguments(SpatialTransformerParam::__FIELDS__())
187 .describe("Applies a spatial transformer to input feature map.");
188 
189 }  // namespace op
190 }  // namespace mxnet
191