1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements.  See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership.  The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License.  You may obtain a copy of the License at
9  *
10  *   http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied.  See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
/*!
 * \file bilinear_sampler-inl.h
 * \brief BilinearSampler operator: samples input feature maps at locations
 *        given by a normalized sampling grid using bilinear interpolation.
 * \author Xu Dong
 */
25 #ifndef MXNET_OPERATOR_BILINEAR_SAMPLER_INL_H_
26 #define MXNET_OPERATOR_BILINEAR_SAMPLER_INL_H_
27 
28 #include <dmlc/logging.h>
29 #include <dmlc/parameter.h>
30 #include <mxnet/operator.h>
31 #include <vector>
32 #include <map>
33 #include <string>
34 #include <utility>
35 #include "./operator_common.h"
36 
37 namespace mxnet {
38 namespace op {
39 
namespace bs {
// Index constants for this operator's input/output TBlob vectors.
// Inputs: kData is the 4D feature-map batch, kGrid the 4D sampling grid
// (see InferShape below for the exact layouts).
enum BilinearSamplerOpInputs {kData, kGrid};
// Outputs: kOut is the sampled result; kTmp is an auxiliary output that is
// hidden from users (NumVisibleOutputs() returns 1 while NumOutputs() is 2).
enum BilinearSamplerOpOutputs {kOut, kTmp};
}  // namespace bs
44 
/*!
 * \brief Parameters of the BilinearSampler operator.
 *
 * The only tunable is an optional tri-state flag to disable cuDNN:
 * unset (default) / true / false. It is presumably consulted by the
 * GPU dispatch code when MXNet is built with cuDNN — not visible in
 * this header.
 */
struct BilinearSamplerParam : public dmlc::Parameter<BilinearSamplerParam> {
  // Tri-state: dmlc::optional<bool>() means "not specified".
  dmlc::optional<bool> cudnn_off;
  DMLC_DECLARE_PARAMETER(BilinearSamplerParam) {
    DMLC_DECLARE_FIELD(cudnn_off).set_default(dmlc::optional<bool>())
        .describe("whether to turn cudnn off");
  }
};
52 
53 template<typename xpu, typename DType>
54 class BilinearSamplerOp : public Operator {
55  public:
BilinearSamplerOp(BilinearSamplerParam p)56   explicit BilinearSamplerOp(BilinearSamplerParam p) {
57     this->param_ = p;
58   }
59 
Forward(const OpContext & ctx,const std::vector<TBlob> & in_data,const std::vector<OpReqType> & req,const std::vector<TBlob> & out_data,const std::vector<TBlob> & aux_args)60   virtual void Forward(const OpContext &ctx,
61                        const std::vector<TBlob> &in_data,
62                        const std::vector<OpReqType> &req,
63                        const std::vector<TBlob> &out_data,
64                        const std::vector<TBlob> &aux_args) {
65     using namespace mshadow;
66     using namespace mshadow::expr;
67     CHECK_EQ(req[bs::kOut], kWriteTo);
68     CHECK_EQ(in_data.size(), 2U);
69     Stream<xpu> *s = ctx.get_stream<xpu>();
70 
71     Tensor<xpu, 4, DType> data = in_data[bs::kData].get<xpu, 4, DType>(s);
72     Tensor<xpu, 4, DType> grid = in_data[bs::kGrid].get<xpu, 4, DType>(s);
73     Tensor<xpu, 4, DType> out = out_data[bs::kOut].get<xpu, 4, DType>(s);
74 
75     BilinearSamplerForward(out, data, grid);
76   }
77 
Backward(const OpContext & ctx,const std::vector<TBlob> & out_grad,const std::vector<TBlob> & in_data,const std::vector<TBlob> & out_data,const std::vector<OpReqType> & req,const std::vector<TBlob> & in_grad,const std::vector<TBlob> & aux_args)78   virtual void Backward(const OpContext &ctx,
79                         const std::vector<TBlob> &out_grad,
80                         const std::vector<TBlob> &in_data,
81                         const std::vector<TBlob> &out_data,
82                         const std::vector<OpReqType> &req,
83                         const std::vector<TBlob> &in_grad,
84                         const std::vector<TBlob> &aux_args) {
85     using namespace mshadow;
86     using namespace mshadow::expr;
87     CHECK_EQ(in_data.size(), 2U);
88     CHECK_NE(req[bs::kData], kWriteInplace);
89     CHECK_NE(req[bs::kGrid], kWriteInplace);
90     Stream<xpu> *s = ctx.get_stream<xpu>();
91 
92     Tensor<xpu, 4, DType> data = in_data[bs::kData].get<xpu, 4, DType>(s);
93     Tensor<xpu, 4, DType> grid = in_data[bs::kGrid].get<xpu, 4, DType>(s);
94     Tensor<xpu, 4, DType> gdata = in_grad[bs::kData].get<xpu, 4, DType>(s);
95     Tensor<xpu, 4, DType> ggrid = in_grad[bs::kGrid].get<xpu, 4, DType>(s);
96     Tensor<xpu, 4, DType> grad = out_grad[bs::kOut].get<xpu, 4, DType>(s);
97     if (req[bs::kData] == kNullOp && req[bs::kGrid] == kNullOp) {
98       return;
99     } else {
100       if (req[bs::kData] == kWriteTo) {
101         gdata = scalar<DType>(0.0f);
102       }
103       if (req[bs::kGrid] == kWriteTo) {
104         ggrid = scalar<DType>(0.0f);
105       }
106       BilinearSamplerBackward(gdata, ggrid, grad, data, grid, req[bs::kData], req[bs::kGrid]);
107     }
108   }
109 
110  private:
111   BilinearSamplerParam param_;
112 };  // class BilinearSamplerOp
113 
// Factory for the device-specific BilinearSamplerOp<xpu, DType>;
// specializations are defined in the corresponding .cc/.cu files.
template<typename xpu>
Operator* CreateOp(BilinearSamplerParam param, int dtype);
116 
117 #if DMLC_USE_CXX11
118 class BilinearSamplerProp : public OperatorProperty {
119  public:
NumVisibleOutputs()120   int NumVisibleOutputs() const override {
121     return 1;
122   }
123 
NumOutputs()124   int NumOutputs() const override {
125     return 2;
126   }
127 
ListArguments()128   std::vector<std::string> ListArguments() const override {
129     return {"data", "grid"};
130   }
131 
ListOutputs()132   std::vector<std::string> ListOutputs() const override {
133     return {"output", "tmp"};
134   }
135 
Init(const std::vector<std::pair<std::string,std::string>> & kwargs)136   void Init(const std::vector<std::pair<std::string, std::string> >& kwargs) override {
137     param_.Init(kwargs);
138   }
139 
GetParams()140   std::map<std::string, std::string> GetParams() const override {
141     return param_.__DICT__();
142   }
143 
InferShape(mxnet::ShapeVector * in_shape,mxnet::ShapeVector * out_shape,mxnet::ShapeVector * aux_shape)144   bool InferShape(mxnet::ShapeVector *in_shape,
145                   mxnet::ShapeVector *out_shape,
146                   mxnet::ShapeVector *aux_shape) const override {
147     using namespace mshadow;
148     CHECK_EQ(in_shape->size(), 2U) << "Input:[data, grid]";
149     const mxnet::TShape &dshape = (*in_shape)[bs::kData];
150     const mxnet::TShape &lshape = (*in_shape)[bs::kGrid];
151     if (!shape_is_known(dshape)) return false;
152     CHECK_EQ(dshape.ndim(), 4U) \
153         << "input data should be 4D in batch-num_filter-y-x";
154     if (!shape_is_known(lshape)) return false;
155     CHECK_EQ(lshape.ndim(), 4U) \
156       << "Sampler grid should be 4D in batch-2-y-x";
157     CHECK_EQ(dshape[0], lshape[0]);
158     CHECK_EQ(lshape[1], 2U) << "incorrect grid shape[1], should be 2";
159     // target height
160     CHECK_GT(lshape[2], 0U) \
161             << "incorrect grid_shape: " << lshape[2];
162     // target width
163     CHECK_GT(lshape[3], 0U) \
164         << "incorrect grid_shape: " << lshape[3];
165     out_shape->clear();
166     // output_shape : (data.shape[0], data.shape[1], grid.shape[2], grid.shape[3])
167     out_shape->push_back(dshape);
168     (*out_shape)[bs::kOut][2] = lshape[2];
169     (*out_shape)[bs::kOut][3] = lshape[3];
170     out_shape->push_back(Shape4(lshape[0], lshape[2], lshape[3], 2));
171     return true;
172   }
173 
InferType(std::vector<int> * in_type,std::vector<int> * out_type,std::vector<int> * aux_type)174   bool InferType(std::vector<int> *in_type,
175                    std::vector<int> *out_type,
176                    std::vector<int> *aux_type) const override {
177       int dtype = -1;
178       for (int type : *in_type) {
179         if (dtype == -1) {
180           dtype = type;
181         } else {
182           CHECK(type == dtype ||
183               type == -1) <<
184                 "Non-uniform data type in BilinearSampler";
185         }
186       }
187       if (dtype == -1) {
188         LOG(FATAL) << "Not enough information to infer type in BilinearSampler.";
189         return false;
190       }
191       size_t nin = this->ListArguments().size();
192       in_type->clear();
193       for (size_t i = 0; i < nin; ++i) in_type->push_back(dtype);
194       size_t naux = this->ListAuxiliaryStates().size();
195       aux_type->clear();
196       for (size_t i = 0; i < naux; ++i) aux_type->push_back(dtype);
197       size_t nout = this->ListOutputs().size();
198       out_type->clear();
199       for (size_t i = 0; i < nout; ++i) out_type->push_back(dtype);
200       return true;
201     }
202 
Copy()203   OperatorProperty* Copy() const override {
204     auto ptr = new BilinearSamplerProp();
205     ptr->param_ = param_;
206     return ptr;
207   }
208 
TypeString()209   std::string TypeString() const override {
210     return "BilinearSampler";
211   }
212 
DeclareBackwardDependency(const std::vector<int> & out_grad,const std::vector<int> & in_data,const std::vector<int> & out_data)213   std::vector<int> DeclareBackwardDependency(
214     const std::vector<int> &out_grad,
215     const std::vector<int> &in_data,
216     const std::vector<int> &out_data) const override {
217     return {out_grad[bs::kOut],
218             in_data[bs::kData],
219             out_data[bs::kTmp],
220             in_data[bs::kGrid]};
221   }
222 
CreateOperator(Context ctx)223   Operator* CreateOperator(Context ctx) const override {
224     LOG(FATAL) << "Not Implemented.";
225     return nullptr;
226   }
227 
228   Operator* CreateOperatorEx(Context ctx, mxnet::ShapeVector *in_shape,
229                              std::vector<int> *in_type) const override;
230 
231  private:
232   BilinearSamplerParam param_;
233 };  // class BilinearSamplerProp
234 #endif  // DMLC_USE_CXX11
235 }  // namespace op
236 }  // namespace mxnet
237 #endif  // MXNET_OPERATOR_BILINEAR_SAMPLER_INL_H_
238