/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file spatial_transformer.cc
 * \brief Bilinear sampling forward/backward routines and registration of the
 *        SpatialTransformer operator.
 * \author Wei Wu
 */

#include "./spatial_transformer-inl.h"

// NOTE(review): the template argument lists in this file (template<typename DType>,
// Tensor<cpu, N, DType>, static_cast<index_t>, CreateOp<cpu>, SpatialTransformerOp<cpu, DType>,
// std::vector<int>) were restored from how the values are used; confirm the device tag and the
// rank of the grid tensor against spatial_transformer-inl.h.

namespace mshadow {

/*!
 * \brief Inclusive range test.
 * \param value the value to test
 * \param lowerBound smallest accepted value
 * \param upperBound largest accepted value
 * \return true iff lowerBound <= value <= upperBound
 */
template<typename DType>
static MSHADOW_CINLINE bool between(const DType value,
                                    const DType lowerBound,
                                    const DType upperBound) {
  return value >= lowerBound && value <= upperBound;
}

/*!
 * \brief Samples `input` at the locations given by `grid_src` using bilinear interpolation.
 *
 * For every sample n the grid buffer stores o_h * o_w x-coordinates followed by
 * o_h * o_w y-coordinates. A coordinate g is mapped to input pixel space as
 * (g + 1) * (size - 1) / 2, i.e. -1 maps to pixel 0 and +1 maps to pixel size - 1.
 * Any of the four neighbouring pixels that falls outside the input contributes zero.
 *
 * \param output destination tensor, indexed as (n, c, h, w); every element is overwritten
 * \param input source tensor, indexed as (n, c, h, w); the batch size of `output` is used
 * \param grid_src sampling grid as described above
 */
template<typename DType>
inline void BilinearSamplingForward(const Tensor<cpu, 4, DType> &output,
                                    const Tensor<cpu, 4, DType> &input,
                                    const Tensor<cpu, 3, DType> grid_src) {
  DType *out = output.dptr_;
  const DType *data = input.dptr_;
  const DType *grid = grid_src.dptr_;
  const index_t o_n = output.size(0), o_c = output.size(1),
                o_h = output.size(2), o_w = output.size(3);
  const index_t i_c = input.size(1), i_h = input.size(2), i_w = input.size(3);
  const index_t lower_bound = 0;
  for (index_t n = 0; n < o_n; ++n) {
    for (index_t c = 0; c < o_c; ++c) {
      for (index_t h = 0; h < o_h; ++h) {
        for (index_t w = 0; w < o_w; ++w) {
          const index_t out_index = n * o_c * o_h * o_w + c * o_h * o_w + h * o_w + w;
          const index_t grid_index = n * o_h * o_w * 2 + h * o_w + w;
          // y-coordinates are stored one o_h * o_w plane after the x-coordinates
          const DType y_real = (*(grid + grid_index + o_h * o_w) + 1) * (i_h - 1) / 2;
          const DType x_real = (*(grid + grid_index) + 1) * (i_w - 1) / 2;
          const auto top_left_y = static_cast<index_t>(std::floor(y_real));
          const auto top_left_x = static_cast<index_t>(std::floor(x_real));
          // interpolation weight of the top row / left column
          const DType top_left_y_w = 1.0 - (y_real - top_left_y);
          const DType top_left_x_w = 1.0 - (x_real - top_left_x);
          const index_t data_index = n * i_c * i_h * i_w + c * i_h * i_w +
                                     top_left_y * i_w + top_left_x;
          // which of the two columns / two rows of the 2x2 neighbourhood lie inside the input
          const bool left_in = between(top_left_x, lower_bound, i_w - 1);
          const bool right_in = between(top_left_x + 1, lower_bound, i_w - 1);
          const bool top_in = between(top_left_y, lower_bound, i_h - 1);
          const bool bottom_in = between(top_left_y + 1, lower_bound, i_h - 1);
          // corners outside the input keep the value zero and are never dereferenced
          DType top_left_v = 0;
          DType top_right_v = 0;
          DType bottom_left_v = 0;
          DType bottom_right_v = 0;
          if (left_in && top_in)
            top_left_v = *(data + data_index);
          if (right_in && top_in)
            top_right_v = *(data + data_index + 1);
          if (left_in && bottom_in)
            bottom_left_v = *(data + data_index + i_w);
          if (right_in && bottom_in)
            bottom_right_v = *(data + data_index + i_w + 1);
          *(out + out_index) = top_left_v * top_left_y_w * top_left_x_w +
                               top_right_v * top_left_y_w * (1.0 - top_left_x_w) +
                               bottom_left_v * (1.0 - top_left_y_w) * top_left_x_w +
                               bottom_right_v * (1.0 - top_left_y_w) * (1.0 - top_left_x_w);
        }
      }
    }
  }
}

/*!
 * \brief Backward pass of the bilinear sampler.
 *
 * The output gradient is scattered onto the (up to four) in-bounds input pixels each output
 * location was interpolated from, and the gradient with respect to the sampling grid is
 * computed by summing over all channels.
 *
 * \param input_grad gradient w.r.t. the input; contributions are accumulated with `+=`
 *        (presumably the caller initialises this buffer - verify against the caller)
 * \param grid_src_data on entry the sampling grid used in the forward pass; on return each
 *        entry has been overwritten in place with the gradient w.r.t. that grid coordinate
 * \param output_grad gradient w.r.t. the sampler output, indexed as (n, c, h, w)
 * \param input_data the input that was sampled in the forward pass
 */
template<typename DType>
inline void BilinearSamplingBackward(const Tensor<cpu, 4, DType> &input_grad,
                                     const Tensor<cpu, 3, DType> &grid_src_data,
                                     const Tensor<cpu, 4, DType> &output_grad,
                                     const Tensor<cpu, 4, DType> &input_data) {
  DType *g_input = input_grad.dptr_;
  DType *grid_src = grid_src_data.dptr_;
  const DType *grad = output_grad.dptr_;
  const DType *data = input_data.dptr_;
  const index_t o_n = output_grad.size(0), o_c = output_grad.size(1),
                o_h = output_grad.size(2), o_w = output_grad.size(3);
  const index_t i_c = input_data.size(1), i_h = input_data.size(2), i_w = input_data.size(3);
  const index_t lower_bound = 0;
  for (index_t n = 0; n < o_n; ++n) {
    for (index_t h = 0; h < o_h; ++h) {
      for (index_t w = 0; w < o_w; ++w) {
        // gradients w.r.t. the sample position, accumulated over all channels
        DType top_left_y_gw = 0.0;
        DType top_left_x_gw = 0.0;
        const index_t grid_src_index = n * o_h * o_w * 2 + h * o_w + w;
        const DType y_real = (*(grid_src + grid_src_index + o_h * o_w) + 1) * (i_h - 1) / 2;
        const DType x_real = (*(grid_src + grid_src_index) + 1) * (i_w - 1) / 2;
        const auto top_left_y = static_cast<index_t>(std::floor(y_real));
        const auto top_left_x = static_cast<index_t>(std::floor(x_real));
        const DType top_left_y_w = 1.0 - (y_real - top_left_y);
        const DType top_left_x_w = 1.0 - (x_real - top_left_x);
        // the bounds checks depend only on the sample position, not on the channel
        const bool left_in = between(top_left_x, lower_bound, i_w - 1);
        const bool right_in = between(top_left_x + 1, lower_bound, i_w - 1);
        const bool top_in = between(top_left_y, lower_bound, i_h - 1);
        const bool bottom_in = between(top_left_y + 1, lower_bound, i_h - 1);
        for (index_t c = 0; c < o_c; ++c) {
          const index_t grad_index = n * o_c * o_h * o_w + c * o_h * o_w + h * o_w + w;
          const index_t data_index = n * i_c * i_h * i_w + c * i_h * i_w +
                                     top_left_y * i_w + top_left_x;
          const DType out_grad = *(grad + grad_index);
          // values of the four neighbouring input pixels; out-of-bounds corners stay zero
          DType top_left_v = 0;
          DType top_right_v = 0;
          DType bottom_left_v = 0;
          DType bottom_right_v = 0;
          if (left_in && top_in) {
            *(g_input + data_index) += out_grad * top_left_y_w * top_left_x_w;
            top_left_v = *(data + data_index);
          }
          if (right_in && top_in) {
            *(g_input + data_index + 1) += out_grad * top_left_y_w * (1.0 - top_left_x_w);
            top_right_v = *(data + data_index + 1);
          }
          if (left_in && bottom_in) {
            *(g_input + data_index + i_w) += out_grad * (1.0 - top_left_y_w) * top_left_x_w;
            bottom_left_v = *(data + data_index + i_w);
          }
          if (right_in && bottom_in) {
            *(g_input + data_index + i_w + 1) += out_grad * (1.0 - top_left_y_w) *
                                                 (1.0 - top_left_x_w);
            bottom_right_v = *(data + data_index + i_w + 1);
          }
          // The bracketed terms are the derivatives of the interpolated value w.r.t. the
          // top/left weights. Each weight is 1 - (coordinate - floor(coordinate)), so the
          // derivative w.r.t. the coordinate itself has the opposite sign, hence `-=`.
          top_left_y_gw -= out_grad * (top_right_v - bottom_right_v +
                           (top_left_v - top_right_v - bottom_left_v + bottom_right_v)
                           * top_left_x_w);
          top_left_x_gw -= out_grad * (bottom_left_v - bottom_right_v +
                           (top_left_v - top_right_v - bottom_left_v + bottom_right_v)
                           * top_left_y_w);
        }
        // Chain rule through the (g + 1) * (size - 1) / 2 mapping; the grid buffer is
        // overwritten in place with its own gradient.
        *(grid_src + grid_src_index + o_h * o_w) = top_left_y_gw * (i_h - 1) / 2;
        *(grid_src + grid_src_index) = top_left_x_gw * (i_w - 1) / 2;
      }
    }
  }
}
}  // namespace mshadow

namespace mxnet {
namespace op {

/*!
 * \brief Creates the SpatialTransformer operator for the element type selected by `dtype`.
 * \param param operator parameters forwarded to the operator's constructor
 * \param dtype type flag dispatched on by MSHADOW_REAL_TYPE_SWITCH
 * \return newly allocated operator; the raw pointer is handed to the caller
 */
template<>
Operator* CreateOp<cpu>(SpatialTransformerParam param, int dtype) {
  Operator *op = nullptr;
  MSHADOW_REAL_TYPE_SWITCH(dtype, DType, {
    op = new SpatialTransformerOp<cpu, DType>(param);
  })
  return op;
}

/*!
 * \brief Dispatches operator creation through DO_BIND_DISPATCH using the type of the first
 *        input.
 */
Operator *SpatialTransformerProp::CreateOperatorEx(Context ctx,
                                                   mxnet::ShapeVector *in_shape,
                                                   std::vector<int> *in_type) const {
  DO_BIND_DISPATCH(CreateOp, param_, (*in_type)[0]);
}

DMLC_REGISTER_PARAMETER(SpatialTransformerParam);

MXNET_REGISTER_OP_PROPERTY(SpatialTransformer, SpatialTransformerProp)
.add_argument("data", "NDArray-or-Symbol", "Input data to the SpatialTransformerOp.")
.add_argument("loc", "NDArray-or-Symbol",
              "localisation net, the output dim should be 6 when transform_type "
              "is affine. You shold initialize the weight and bias with identity tranform.")
.add_arguments(SpatialTransformerParam::__FIELDS__())
.describe("Applies a spatial transformer to input feature map.");

}  // namespace op
}  // namespace mxnet