1 /* 2 * Licensed to the Apache Software Foundation (ASF) under one 3 * or more contributor license agreements. See the NOTICE file 4 * distributed with this work for additional information 5 * regarding copyright ownership. The ASF licenses this file 6 * to you under the Apache License, Version 2.0 (the 7 * "License"); you may not use this file except in compliance 8 * with the License. You may obtain a copy of the License at 9 * 10 * http://www.apache.org/licenses/LICENSE-2.0 11 * 12 * Unless required by applicable law or agreed to in writing, 13 * software distributed under the License is distributed on an 14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 15 * KIND, either express or implied. See the License for the 16 * specific language governing permissions and limitations 17 * under the License. 18 */ 19 20 /*! 21 * \file range.h 22 * \brief support generating a range vector 23 * \author Xingjian Shi 24 */ 25 #ifndef MSHADOW_EXTENSION_RANGE_H_ 26 #define MSHADOW_EXTENSION_RANGE_H_ 27 28 #include "../extension.h" 29 30 namespace mshadow { 31 namespace expr { 32 /*! 33 * \brief Generate a range vector similar to python: range(start, stop[, step][, repeat]). 34 If step is positive, the last element is the largest start + i * step less than stop 35 If step is negative, the last element is the smallest start + i * step greater than stop. 36 All elements are repeated for `repeat` times, e.g range(0, 4, 2, 3) --> 0, 0, 0, 2, 2, 2 37 * \tparam SrcExp type of lhs expression 38 * \tparam IndexExp type of index expression 39 * \tparam DType the type of elements 40 */ 41 template<typename DType> 42 struct RangeExp: 43 public Exp<RangeExp<DType>, DType, type::kMapper> { 44 const DType start_; 45 const DType stop_; 46 const DType step_; 47 const int repeat_; 48 /*! \brief constructor */ RangeExpRangeExp49 RangeExp(DType start, DType stop, DType step, int repeat) 50 : start_(start), stop_(stop), step_(step), repeat_(repeat) {} 51 }; 52 53 template<typename DType> 54 inline RangeExp<DType> 55 range(DType start, DType stop, DType step = 1, int repeat = 1) { 56 return RangeExp<DType>(start, stop, step, repeat); 57 } 58 59 //---------------------- 60 // Execution plan 61 //---------------------- 62 template<typename DType> 63 struct Plan<RangeExp<DType>, DType> { 64 public: 65 explicit Plan(const RangeExp<DType> &e) 66 : start_(e.start_), 67 stop_(e.stop_), 68 step_(e.step_), 69 repeat_(e.repeat_) { 70 } 71 MSHADOW_XINLINE DType Eval(index_t y, index_t x) const { 72 return start_ + static_cast<DType>((static_cast<int>(x) / repeat_)) * step_; 73 } 74 75 private: 76 const DType start_; 77 const DType stop_; 78 const DType step_; 79 const int repeat_; 80 }; 81 82 template<typename DType> 83 inline Plan<RangeExp<DType>, DType> 84 MakePlan(const RangeExp<DType> &exp) { 85 return Plan<RangeExp<DType>, DType>(exp); 86 } 87 88 89 template<typename DType> 90 inline int RangeOutSize(DType start, DType stop, DType step, int repeat) { 91 return repeat * ((stop - start - 1) / step + 1); 92 } 93 94 template<> 95 inline int RangeOutSize<float>(float start, float stop, float step, int repeat) { 96 double d_start = static_cast<double>(start); 97 double d_stop = static_cast<double>(stop); 98 double d_step = static_cast<double>(step); 99 return repeat * static_cast<int>(ceil((d_stop - d_start) / d_step)); 100 } 101 102 template<> 103 inline int RangeOutSize<double>(double start, double stop, double step, int repeat) { 104 return repeat * static_cast<int>(ceil((stop - start) / step)); 105 } 106 107 108 template<int dim, typename DType> 109 struct ShapeCheck<dim, RangeExp<DType> > { 110 inline static Shape<dim> 111 Check(const RangeExp<DType> &t) { 112 CHECK(dim == 1) 113 << "RangeExp only support 1 dimension output, received " << dim; 114 CHECK(t.step_ != 0) 115 << "RangeExp does not support step=0, received " << t.step_; 116 CHECK(t.repeat_ > 0) 117 << "RangeExp only supports repeat > 0, received " << t.repeat_; 118 if (t.step_ > 0) { 119 CHECK(t.start_ < t.stop_) << "RangeExp does not support (start, stop, step) = " 120 << "(" << t.start_ << "," << t.stop_ << "," << t.step_ << ")"; 121 } else { 122 CHECK(t.start_ > t.stop_) << "RangeExp does not support (start, stop, step)= " 123 << "(" << t.start_ << "," << t.stop_ << "," << t.step_ << ")"; 124 } 125 return Shape1(RangeOutSize<DType>(t.start_, t.stop_, t.step_, t.repeat_)); 126 } 127 }; 128 129 template<typename DType> 130 struct ExpInfo<RangeExp<DType> > { 131 static const int kDim = 1; 132 static const int kDevMask = 0xffff; 133 }; 134 } // namespace expr 135 } // namespace mshadow 136 #endif // MSHADOW_EXTENSION_RANGE_H_ 137