1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
/*!
 * \file spatial_upsampling_nearest.h
 * \brief nearest-neighbor spatial upsampling expression
 * \author Bing Xu
 */
25 #ifndef MSHADOW_EXTENSION_SPATIAL_UPSAMPLING_NEAREST_H_
26 #define MSHADOW_EXTENSION_SPATIAL_UPSAMPLING_NEAREST_H_
27 #include "../extension.h"
28
29 namespace mshadow {
30 namespace expr {
31
32 /*! \brief nearest neighboor upsampling
33 * out(x, y) = in(int(x / scale_x), int(y / scale_y))
34 * \tparam SrcExp source expression
35 * \tparam DType data type
36 * \tparam srcdim source dimension
37 */
38 template<typename SrcExp, typename DType, int srcdim>
39 struct UpSamplingNearestExp :
40 public MakeTensorExp<UpSamplingNearestExp<SrcExp, DType, srcdim>,
41 SrcExp, srcdim, DType> {
42 /*! \brief source oprand */
43 const SrcExp &src_;
44 /*! \brief up sampling scale */
45 index_t scale_;
46 /*! \brief constructor */
UpSamplingNearestExpUpSamplingNearestExp47 UpSamplingNearestExp(const SrcExp &src, index_t scale)
48 : src_(src), scale_(scale) {
49 this->shape_ = ShapeCheck<srcdim, SrcExp>::Check(src_);
50 this->shape_[srcdim - 2] *= scale_;
51 this->shape_[srcdim - 1] *= scale_;
52 }
53 };
54
55
56 template<typename SrcExp, typename DType, int etype>
57 inline UpSamplingNearestExp<SrcExp, DType, ExpInfo<SrcExp>::kDim>
upsampling_nearest(const Exp<SrcExp,DType,etype> & src,index_t scale)58 upsampling_nearest(const Exp<SrcExp, DType, etype> &src, index_t scale) {
59 TypeCheckPass<ExpInfo<SrcExp>::kDim >= 2>
60 ::Error_Expression_Does_Not_Meet_Dimension_Req();
61 return UpSamplingNearestExp<SrcExp, DType, ExpInfo<SrcExp>::kDim>(src.self(), scale);
62 }
63
64 template<typename SrcExp, typename DType, int srcdim>
65 struct Plan<UpSamplingNearestExp<SrcExp, DType, srcdim>, DType> {
66 public:
67 explicit Plan(const UpSamplingNearestExp<SrcExp, DType, srcdim> &e)
68 : src_(MakePlan(e.src_)),
69 scale_(e.scale_),
70 new_height_(e.shape_[srcdim - 2]),
71 src_height_(static_cast<index_t>(e.shape_[srcdim - 2] / e.scale_)) {}
72 MSHADOW_XINLINE DType Eval(index_t i, index_t j) const {
73 const index_t x = j;
74 const index_t y = i % new_height_;
75 const index_t c = i / new_height_;
76 const index_t h = static_cast<index_t>(y / scale_);
77 const index_t w = static_cast<index_t>(x / scale_);
78 return src_.Eval(c * src_height_ + h, w);
79 }
80
81 private:
82 Plan<SrcExp, DType> src_;
83 const index_t scale_;
84 const index_t new_height_;
85 const index_t src_height_;
86 };
87 } // namespace expr
88 } // namespace mshadow
89 #endif // MSHADOW_EXTENSION_SPATIAL_UPSAMPLING_NEAREST_H_
90