1 /* 2 * Licensed to the Apache Software Foundation (ASF) under one 3 * or more contributor license agreements. See the NOTICE file 4 * distributed with this work for additional information 5 * regarding copyright ownership. The ASF licenses this file 6 * to you under the Apache License, Version 2.0 (the 7 * "License"); you may not use this file except in compliance 8 * with the License. You may obtain a copy of the License at 9 * 10 * http://www.apache.org/licenses/LICENSE-2.0 11 * 12 * Unless required by applicable law or agreed to in writing, 13 * software distributed under the License is distributed on an 14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 15 * KIND, either express or implied. See the License for the 16 * specific language governing permissions and limitations 17 * under the License. 18 */ 19 20 /*! 21 * \file bilinear_Sampler-inl.h 22 * \brief 23 * \author Xu Dong 24 */ 25 #ifndef MXNET_OPERATOR_BILINEAR_SAMPLER_INL_H_ 26 #define MXNET_OPERATOR_BILINEAR_SAMPLER_INL_H_ 27 28 #include <dmlc/logging.h> 29 #include <dmlc/parameter.h> 30 #include <mxnet/operator.h> 31 #include <vector> 32 #include <map> 33 #include <string> 34 #include <utility> 35 #include "./operator_common.h" 36 37 namespace mxnet { 38 namespace op { 39 40 namespace bs { 41 enum BilinearSamplerOpInputs {kData, kGrid}; 42 enum BilinearSamplerOpOutputs {kOut, kTmp}; 43 } 44 45 struct BilinearSamplerParam : public dmlc::Parameter<BilinearSamplerParam> { 46 dmlc::optional<bool> cudnn_off; DMLC_DECLARE_PARAMETERBilinearSamplerParam47 DMLC_DECLARE_PARAMETER(BilinearSamplerParam) { 48 DMLC_DECLARE_FIELD(cudnn_off).set_default(dmlc::optional<bool>()) 49 .describe("whether to turn cudnn off"); 50 } 51 }; 52 53 template<typename xpu, typename DType> 54 class BilinearSamplerOp : public Operator { 55 public: BilinearSamplerOp(BilinearSamplerParam p)56 explicit BilinearSamplerOp(BilinearSamplerParam p) { 57 this->param_ = p; 58 } 59 Forward(const OpContext & ctx,const std::vector<TBlob> & in_data,const std::vector<OpReqType> & req,const std::vector<TBlob> & out_data,const std::vector<TBlob> & aux_args)60 virtual void Forward(const OpContext &ctx, 61 const std::vector<TBlob> &in_data, 62 const std::vector<OpReqType> &req, 63 const std::vector<TBlob> &out_data, 64 const std::vector<TBlob> &aux_args) { 65 using namespace mshadow; 66 using namespace mshadow::expr; 67 CHECK_EQ(req[bs::kOut], kWriteTo); 68 CHECK_EQ(in_data.size(), 2U); 69 Stream<xpu> *s = ctx.get_stream<xpu>(); 70 71 Tensor<xpu, 4, DType> data = in_data[bs::kData].get<xpu, 4, DType>(s); 72 Tensor<xpu, 4, DType> grid = in_data[bs::kGrid].get<xpu, 4, DType>(s); 73 Tensor<xpu, 4, DType> out = out_data[bs::kOut].get<xpu, 4, DType>(s); 74 75 BilinearSamplerForward(out, data, grid); 76 } 77 Backward(const OpContext & ctx,const std::vector<TBlob> & out_grad,const std::vector<TBlob> & in_data,const std::vector<TBlob> & out_data,const std::vector<OpReqType> & req,const std::vector<TBlob> & in_grad,const std::vector<TBlob> & aux_args)78 virtual void Backward(const OpContext &ctx, 79 const std::vector<TBlob> &out_grad, 80 const std::vector<TBlob> &in_data, 81 const std::vector<TBlob> &out_data, 82 const std::vector<OpReqType> &req, 83 const std::vector<TBlob> &in_grad, 84 const std::vector<TBlob> &aux_args) { 85 using namespace mshadow; 86 using namespace mshadow::expr; 87 CHECK_EQ(in_data.size(), 2U); 88 CHECK_NE(req[bs::kData], kWriteInplace); 89 CHECK_NE(req[bs::kGrid], kWriteInplace); 90 Stream<xpu> *s = ctx.get_stream<xpu>(); 91 92 Tensor<xpu, 4, DType> data = in_data[bs::kData].get<xpu, 4, DType>(s); 93 Tensor<xpu, 4, DType> grid = in_data[bs::kGrid].get<xpu, 4, DType>(s); 94 Tensor<xpu, 4, DType> gdata = in_grad[bs::kData].get<xpu, 4, DType>(s); 95 Tensor<xpu, 4, DType> ggrid = in_grad[bs::kGrid].get<xpu, 4, DType>(s); 96 Tensor<xpu, 4, DType> grad = out_grad[bs::kOut].get<xpu, 4, DType>(s); 97 if (req[bs::kData] == kNullOp && req[bs::kGrid] == kNullOp) { 98 return; 99 } else { 100 if (req[bs::kData] == kWriteTo) { 101 gdata = scalar<DType>(0.0f); 102 } 103 if (req[bs::kGrid] == kWriteTo) { 104 ggrid = scalar<DType>(0.0f); 105 } 106 BilinearSamplerBackward(gdata, ggrid, grad, data, grid, req[bs::kData], req[bs::kGrid]); 107 } 108 } 109 110 private: 111 BilinearSamplerParam param_; 112 }; // class BilinearSamplerOp 113 114 template<typename xpu> 115 Operator* CreateOp(BilinearSamplerParam param, int dtype); 116 117 #if DMLC_USE_CXX11 118 class BilinearSamplerProp : public OperatorProperty { 119 public: NumVisibleOutputs()120 int NumVisibleOutputs() const override { 121 return 1; 122 } 123 NumOutputs()124 int NumOutputs() const override { 125 return 2; 126 } 127 ListArguments()128 std::vector<std::string> ListArguments() const override { 129 return {"data", "grid"}; 130 } 131 ListOutputs()132 std::vector<std::string> ListOutputs() const override { 133 return {"output", "tmp"}; 134 } 135 Init(const std::vector<std::pair<std::string,std::string>> & kwargs)136 void Init(const std::vector<std::pair<std::string, std::string> >& kwargs) override { 137 param_.Init(kwargs); 138 } 139 GetParams()140 std::map<std::string, std::string> GetParams() const override { 141 return param_.__DICT__(); 142 } 143 InferShape(mxnet::ShapeVector * in_shape,mxnet::ShapeVector * out_shape,mxnet::ShapeVector * aux_shape)144 bool InferShape(mxnet::ShapeVector *in_shape, 145 mxnet::ShapeVector *out_shape, 146 mxnet::ShapeVector *aux_shape) const override { 147 using namespace mshadow; 148 CHECK_EQ(in_shape->size(), 2U) << "Input:[data, grid]"; 149 const mxnet::TShape &dshape = (*in_shape)[bs::kData]; 150 const mxnet::TShape &lshape = (*in_shape)[bs::kGrid]; 151 if (!shape_is_known(dshape)) return false; 152 CHECK_EQ(dshape.ndim(), 4U) \ 153 << "input data should be 4D in batch-num_filter-y-x"; 154 if (!shape_is_known(lshape)) return false; 155 CHECK_EQ(lshape.ndim(), 4U) \ 156 << "Sampler grid should be 4D in batch-2-y-x"; 157 CHECK_EQ(dshape[0], lshape[0]); 158 CHECK_EQ(lshape[1], 2U) << "incorrect grid shape[1], should be 2"; 159 // target height 160 CHECK_GT(lshape[2], 0U) \ 161 << "incorrect grid_shape: " << lshape[2]; 162 // target width 163 CHECK_GT(lshape[3], 0U) \ 164 << "incorrect grid_shape: " << lshape[3]; 165 out_shape->clear(); 166 // output_shape : (data.shape[0], data.shape[1], grid.shape[2], grid.shape[3]) 167 out_shape->push_back(dshape); 168 (*out_shape)[bs::kOut][2] = lshape[2]; 169 (*out_shape)[bs::kOut][3] = lshape[3]; 170 out_shape->push_back(Shape4(lshape[0], lshape[2], lshape[3], 2)); 171 return true; 172 } 173 InferType(std::vector<int> * in_type,std::vector<int> * out_type,std::vector<int> * aux_type)174 bool InferType(std::vector<int> *in_type, 175 std::vector<int> *out_type, 176 std::vector<int> *aux_type) const override { 177 int dtype = -1; 178 for (int type : *in_type) { 179 if (dtype == -1) { 180 dtype = type; 181 } else { 182 CHECK(type == dtype || 183 type == -1) << 184 "Non-uniform data type in BilinearSampler"; 185 } 186 } 187 if (dtype == -1) { 188 LOG(FATAL) << "Not enough information to infer type in BilinearSampler."; 189 return false; 190 } 191 size_t nin = this->ListArguments().size(); 192 in_type->clear(); 193 for (size_t i = 0; i < nin; ++i) in_type->push_back(dtype); 194 size_t naux = this->ListAuxiliaryStates().size(); 195 aux_type->clear(); 196 for (size_t i = 0; i < naux; ++i) aux_type->push_back(dtype); 197 size_t nout = this->ListOutputs().size(); 198 out_type->clear(); 199 for (size_t i = 0; i < nout; ++i) out_type->push_back(dtype); 200 return true; 201 } 202 Copy()203 OperatorProperty* Copy() const override { 204 auto ptr = new BilinearSamplerProp(); 205 ptr->param_ = param_; 206 return ptr; 207 } 208 TypeString()209 std::string TypeString() const override { 210 return "BilinearSampler"; 211 } 212 DeclareBackwardDependency(const std::vector<int> & out_grad,const std::vector<int> & in_data,const std::vector<int> & out_data)213 std::vector<int> DeclareBackwardDependency( 214 const std::vector<int> &out_grad, 215 const std::vector<int> &in_data, 216 const std::vector<int> &out_data) const override { 217 return {out_grad[bs::kOut], 218 in_data[bs::kData], 219 out_data[bs::kTmp], 220 in_data[bs::kGrid]}; 221 } 222 CreateOperator(Context ctx)223 Operator* CreateOperator(Context ctx) const override { 224 LOG(FATAL) << "Not Implemented."; 225 return nullptr; 226 } 227 228 Operator* CreateOperatorEx(Context ctx, mxnet::ShapeVector *in_shape, 229 std::vector<int> *in_type) const override; 230 231 private: 232 BilinearSamplerParam param_; 233 }; // class BilinearSamplerProp 234 #endif // DMLC_USE_CXX11 235 } // namespace op 236 } // namespace mxnet 237 #endif // MXNET_OPERATOR_BILINEAR_SAMPLER_INL_H_ 238