1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements.  See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership.  The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License.  You may obtain a copy of the License at
9  *
10  *   http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied.  See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
/*!
 * \file range.h
 * \brief Support for generating range vectors.
 * \author Xingjian Shi
 */
25 #ifndef MSHADOW_EXTENSION_RANGE_H_
26 #define MSHADOW_EXTENSION_RANGE_H_
27 
28 #include "../extension.h"
29 
30 namespace mshadow {
31 namespace expr {
/*!
 * \brief Generate a range vector similar to Python's range(start, stop[, step]),
 *        extended with an optional repeat count: range(start, stop[, step][, repeat]).
 *        If step is positive, the last element is the largest start + i * step less than stop.
 *        If step is negative, the last element is the smallest start + i * step greater than stop.
 *        Every element is repeated `repeat` times, e.g. range(0, 4, 2, 3) --> 0, 0, 0, 2, 2, 2.
 * \tparam DType the type of elements
 */
41 template<typename DType>
42 struct RangeExp:
43       public Exp<RangeExp<DType>, DType, type::kMapper> {
44   const DType start_;
45   const DType stop_;
46   const DType step_;
47   const int repeat_;
48   /*! \brief constructor */
RangeExpRangeExp49   RangeExp(DType start, DType stop, DType step, int repeat)
50       : start_(start), stop_(stop), step_(step), repeat_(repeat) {}
51 };
52 
53 template<typename DType>
54 inline RangeExp<DType>
55 range(DType start, DType stop, DType step = 1, int repeat = 1) {
56   return RangeExp<DType>(start, stop, step, repeat);
57 }
58 
59 //----------------------
60 // Execution plan
61 //----------------------
62 template<typename DType>
63 struct Plan<RangeExp<DType>, DType> {
64  public:
65   explicit Plan(const RangeExp<DType> &e)
66       : start_(e.start_),
67         stop_(e.stop_),
68         step_(e.step_),
69         repeat_(e.repeat_) {
70   }
71   MSHADOW_XINLINE DType Eval(index_t y, index_t x) const {
72     return start_ + static_cast<DType>((static_cast<int>(x) / repeat_)) * step_;
73   }
74 
75  private:
76   const DType start_;
77   const DType stop_;
78   const DType step_;
79   const int repeat_;
80 };
81 
82 template<typename DType>
83 inline Plan<RangeExp<DType>, DType>
84 MakePlan(const RangeExp<DType> &exp) {
85   return Plan<RangeExp<DType>, DType>(exp);
86 }
87 
88 
/*!
 * \brief Number of elements produced by range(start, stop, step, repeat).
 *
 * Primary template, used for every DType without an explicit specialization
 * (float and double are specialized separately). It assumes the caller has
 * already checked that step != 0, repeat > 0 and that start lies strictly on
 * the correct side of stop for the sign of step. Under those preconditions the
 * result is repeat * ceil((stop - start) / step), computed with truncating
 * division:
 *   step > 0: (stop - start - 1) / step + 1
 *   step < 0: (stop - start + 1) / step + 1
 * The "- 1" bias is only valid for positive steps; applying it to a negative
 * step over-counts, e.g. (5, 0, -2) would give 4 instead of 3 (5, 3, 1).
 * \return total element count including repeats
 */
template<typename DType>
inline int RangeOutSize(DType start, DType stop, DType step, int repeat) {
  if (step > 0) {
    return repeat * ((stop - start - 1) / step + 1);
  }
  // Negative step: stop - start is negative, so bias it by +1 (towards zero)
  // to turn truncating division into a ceiling of the magnitude.
  return repeat * ((stop - start + 1) / step + 1);
}
93 
94 template<>
95 inline int RangeOutSize<float>(float start, float stop, float step, int repeat) {
96   double d_start = static_cast<double>(start);
97   double d_stop = static_cast<double>(stop);
98   double d_step = static_cast<double>(step);
99   return repeat * static_cast<int>(ceil((d_stop - d_start) / d_step));
100 }
101 
102 template<>
103 inline int RangeOutSize<double>(double start, double stop, double step, int repeat) {
104   return repeat * static_cast<int>(ceil((stop - start) / step));
105 }
106 
107 
108 template<int dim, typename DType>
109 struct ShapeCheck<dim, RangeExp<DType> > {
110   inline static Shape<dim>
111   Check(const RangeExp<DType> &t) {
112     CHECK(dim == 1)
113         << "RangeExp only support 1 dimension output, received " << dim;
114     CHECK(t.step_ != 0)
115         << "RangeExp does not support step=0, received " << t.step_;
116     CHECK(t.repeat_ > 0)
117       << "RangeExp only supports repeat > 0, received " << t.repeat_;
118     if (t.step_ > 0) {
119       CHECK(t.start_ < t.stop_) << "RangeExp does not support (start, stop, step) = "
120                                 << "(" << t.start_ << "," << t.stop_ << "," << t.step_ << ")";
121     } else {
122       CHECK(t.start_ > t.stop_) << "RangeExp does not support (start, stop, step)= "
123                                 << "(" << t.start_ << "," << t.stop_ << "," << t.step_ << ")";
124     }
125     return Shape1(RangeOutSize<DType>(t.start_, t.stop_, t.step_, t.repeat_));
126   }
127 };
128 
129 template<typename DType>
130 struct ExpInfo<RangeExp<DType> > {
131   static const int kDim = 1;
132   static const int kDevMask = 0xffff;
133 };
134 }  // namespace expr
135 }  // namespace mshadow
136 #endif  // MSHADOW_EXTENSION_RANGE_H_
137