1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements.  See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership.  The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License.  You may obtain a copy of the License at
9  *
10  *   http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied.  See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
20 /*!
21 * \file shape.h
22 * \brief definition of shape
23 * \author Chuntao Hong, Zhang Chen
24 */
25 
26 #ifndef MXNET_CPP_SHAPE_H_
27 #define MXNET_CPP_SHAPE_H_
28 
#include <algorithm>
#include <cctype>
#include <istream>
#include <ostream>
#include <vector>

#include "mxnet-cpp/base.h"
34 
35 namespace mxnet {
36 namespace cpp {
37 
38 /*!
39 * \brief dynamic shape class that can hold shape
40 *   of arbirary dimension
41 */
42 struct Shape {
43  public:
44   /*! \brief constructor */
ShapeShape45   Shape()
46     : ndim_(0),
47     num_heap_allocated_(0),
48     data_heap_(nullptr) {}
49   /*!
50   * \brief constructor from a vector of index_t
51   * \param v the vector
52   */
ShapeShape53   explicit Shape(const std::vector<index_t> &v)
54     : ndim_(v.size()) {
55     if (ndim_ <= kStackCache) {
56       data_heap_ = nullptr;
57       num_heap_allocated_ = 0;
58       std::copy(v.begin(), v.end(), data_stack_);
59     } else {
60       data_heap_ = new index_t[ndim_];
61       num_heap_allocated_ = ndim_;
62       std::copy(v.begin(), v.end(), data_heap_);
63     }
64   }
65   /*!
66   * \brief constructor one dimmension shape
67   * \param s1 size of the first dimmension
68   */
ShapeShape69   explicit Shape(index_t s1)
70     : ndim_(1) {
71     if (ndim_ <= kStackCache) {
72       data_heap_ = nullptr;
73       num_heap_allocated_ = 0;
74       data_stack_[0] = s1;
75     } else {
76       data_heap_ = new index_t[ndim_];
77       num_heap_allocated_ = ndim_;
78       data_heap_[0] = s1;
79     }
80   }
81   /*!
82   * \brief constructor two dimmension shape
83   * \param s1 size of the first dimmension
84   * \param s2 size of the second dimmension
85   */
ShapeShape86   Shape(index_t s1, index_t s2)
87     : ndim_(2) {
88     if (ndim_ <= kStackCache) {
89       data_heap_ = nullptr;
90       num_heap_allocated_ = 0;
91       data_stack_[0] = s1;
92       data_stack_[1] = s2;
93     } else {
94       data_heap_ = new index_t[ndim_];
95       num_heap_allocated_ = ndim_;
96       data_heap_[0] = s1;
97       data_heap_[1] = s2;
98     }
99   }
100   /*!
101   * \brief constructor three dimmension shape
102   * \param s1 size of the first dimmension
103   * \param s2 size of the second dimmension
104   * \param s3 size of the third dimmension
105   */
ShapeShape106   Shape(index_t s1, index_t s2, index_t s3)
107     : ndim_(3) {
108     if (ndim_ <= kStackCache) {
109       data_heap_ = nullptr;
110       num_heap_allocated_ = 0;
111       data_stack_[0] = s1;
112       data_stack_[1] = s2;
113       data_stack_[2] = s3;
114     } else {
115       data_heap_ = new index_t[ndim_];
116       num_heap_allocated_ = ndim_;
117       data_heap_[0] = s1;
118       data_heap_[1] = s2;
119       data_heap_[2] = s3;
120     }
121   }
122   /*!
123   * \brief constructor four dimmension shape
124   * \param s1 size of the first dimmension
125   * \param s2 size of the second dimmension
126   * \param s3 size of the third dimmension
127   * \param s4 size of the fourth dimmension
128   */
ShapeShape129   Shape(index_t s1, index_t s2, index_t s3, index_t s4)
130     : ndim_(4) {
131     if (ndim_ <= kStackCache) {
132       data_heap_ = nullptr;
133       num_heap_allocated_ = 0;
134       data_stack_[0] = s1;
135       data_stack_[1] = s2;
136       data_stack_[2] = s3;
137       data_stack_[3] = s4;
138     } else {
139       data_heap_ = new index_t[ndim_];
140       num_heap_allocated_ = ndim_;
141       data_heap_[0] = s1;
142       data_heap_[1] = s2;
143       data_heap_[2] = s3;
144       data_heap_[3] = s4;
145     }
146   }
147   /*!
148   * \brief constructor five dimmension shape
149   * \param s1 size of the first dimmension
150   * \param s2 size of the second dimmension
151   * \param s3 size of the third dimmension
152   * \param s4 size of the fourth dimmension
153   * \param s5 size of the fifth dimmension
154   */
ShapeShape155   Shape(index_t s1, index_t s2, index_t s3, index_t s4, index_t s5)
156     : ndim_(5) {
157     if (ndim_ <= kStackCache) {
158       data_heap_ = nullptr;
159       num_heap_allocated_ = 0;
160       data_stack_[0] = s1;
161       data_stack_[1] = s2;
162       data_stack_[2] = s3;
163       data_stack_[3] = s4;
164       data_stack_[4] = s5;
165     } else {
166       data_heap_ = new index_t[ndim_];
167       num_heap_allocated_ = ndim_;
168       data_heap_[0] = s1;
169       data_heap_[1] = s2;
170       data_heap_[2] = s3;
171       data_heap_[3] = s4;
172       data_heap_[4] = s5;
173     }
174   }
175   /*!
176   * \brief constructor from Shape
177   * \param s the source shape
178   */
ShapeShape179   Shape(const Shape &s)
180     : ndim_(s.ndim_) {
181     if (ndim_ <= kStackCache) {
182       data_heap_ = nullptr;
183       num_heap_allocated_ = 0;
184       std::copy(s.data_stack_, s.data_stack_ + ndim_, data_stack_);
185     } else {
186       data_heap_ = new index_t[ndim_];
187       num_heap_allocated_ = ndim_;
188       std::copy(s.data_heap_, s.data_heap_ + ndim_, data_heap_);
189     }
190   }
191 #if MSHADOW_IN_CXX11
192   /*!
193   * \brief move constructor from Shape
194   * \param s the source shape
195   */
ShapeShape196   Shape(Shape &&s)
197     : ndim_(s.ndim_),
198     num_heap_allocated_(s.num_heap_allocated_),
199     data_heap_(s.data_heap_) {
200     if (ndim_ <= kStackCache) {
201       std::copy(s.data_stack_, s.data_stack_ + ndim_, data_stack_);
202     }
203     // remove data heap space from s
204     s.data_heap_ = nullptr;
205   }
206 #endif
207   /*! \brief destructor */
~ShapeShape208   ~Shape() {
209     // data_heap_ can be nullptr
210     delete[] data_heap_;
211   }
212   /*!
213   * \brief copy shape from content betwen two iterators
214   * \param begin the beginning of iterator
215   * \param end the end of the iterator
216   * \tparam RandomAccessIterator iterator type
217   */
218   template<typename RandomAccessIterator>
CopyFromShape219   inline void CopyFrom(RandomAccessIterator begin,
220     RandomAccessIterator end) {
221     this->SetDim(end - begin);
222     std::copy(begin, end, data());
223   }
224   /*!
225   * \brief assignment from shape
226   * \param shape source shape
227   * \return reference of self
228   */
229   inline Shape &operator=(const Shape &shape) {
230     this->SetDim(shape.ndim_);
231     const index_t *src = shape.data();
232     std::copy(src, src + ndim_, data());
233     return *this;
234   }
235   /*!
236   * \brief assignment from vector
237   * \param shape source shape
238   * \return reference of self
239   */
240   inline Shape &operator=(const std::vector<index_t> &shape) {
241     this->CopyFrom(shape.begin(), shape.end());
242     return *this;
243   }
244   /*! \return the data content of the shape */
dataShape245   inline const index_t *data() const {
246     return ndim_ <= kStackCache ? data_stack_ : data_heap_;
247   }
248   /*! \return the data content of the shape */
dataShape249   inline index_t *data() {
250     return ndim_ <= kStackCache ? data_stack_ : data_heap_;
251   }
252   /*! \brief return number of dimension of the tensor inside */
ndimShape253   inline index_t ndim(void) const {
254     return ndim_;
255   }
256   /*!
257   * \brief get corresponding index
258   * \param i dimension index
259   * \return the corresponding dimension size
260   */
261   inline index_t &operator[](index_t i) {
262     return data()[i];
263   }
264   /*!
265   * \brief get corresponding index
266   * \param i dimension index
267   * \return the corresponding dimension size
268   */
269   inline const index_t &operator[](index_t i) const {
270     return data()[i];
271   }
272   /*! \brief total number of elements in the tensor */
SizeShape273   inline size_t Size(void) const {
274     size_t size = 1;
275     const index_t *d = this->data();
276     for (index_t i = 0; i < ndim_; ++i) {
277       size *= d[i];
278     }
279     return size;
280   }
281   /*!
282   * \return whether two shape equals
283   * \param s the shape to compare against
284   */
285   inline bool operator==(const Shape &s) const {
286     if (ndim_ != s.ndim_) return false;
287     if (ndim_ <= kStackCache) {
288       for (index_t i = 0; i < ndim_; ++i) {
289         if (data_stack_[i] != s.data_stack_[i]) return false;
290       }
291     } else {
292       for (index_t i = 0; i < ndim_; ++i) {
293         if (data_heap_[i] != s.data_heap_[i]) return false;
294       }
295     }
296     return true;
297   }
298   /*!
299   * \return whether two shape not equals
300   * \param s the shape to compare against
301   */
302   inline bool operator!=(const Shape &s) const {
303     return !(*this == s);
304   }
305 
306   friend std::ostream &operator<<(std::ostream &os, const Shape &shape);
307   friend std::istream &operator>>(std::istream &is, Shape &shape);
308 
309  private:
310   // the shape will be stored in data_stack_
311   // when dimension is smaller than kStackCache
312   // when it is bigger, it will be stored in data_heap_;
313   /*! \brief size of in stack space */
314   static const index_t kStackCache = 5;
315   /*! \brief number of dimnsion of the shape */
316   index_t ndim_;
317   /*! \brief number of cells allocated in data_heap_ */
318   index_t num_heap_allocated_;
319   /*! \brief in stack space used to store shape when it is small */
320   index_t data_stack_[kStackCache];
321   /*! \brief space to store shape when dimension is big*/
322   index_t *data_heap_;
323   /*!
324   * \brief internal function to set the dimension
325   * \param dim the dimension of the shape
326   */
SetDimShape327   inline void SetDim(index_t dim) {
328     if (dim > kStackCache &&
329       dim > num_heap_allocated_) {
330       // data_heap_ can be nullptr
331       delete[] data_heap_;
332       data_heap_ = new index_t[dim];
333       num_heap_allocated_ = dim;
334     }
335     ndim_ = dim;
336   }
337 };
338 
339 /*!
340 * \brief allow string printing of the shape
341 * \param os the output stream
342 * \param shape the shape
343 * \return the ostream
344 */
345 inline std::ostream &operator<<(std::ostream &os, const Shape &shape) {
346   os << '(';
347   for (index_t i = 0; i < shape.ndim(); ++i) {
348     if (i != 0) os << ',';
349     os << static_cast<int>(shape[i]);  // Supports negative Shape 'special codes' for inferring
350   }
351   // python style tuple
352   if (shape.ndim() == 1) os << ',';
353   os << ')';
354   return os;
355 }
356 
357 /*!
358 * \brief read shape from the istream
359 * \param is the input stream
360 * \param shape the shape
361 * \return the istream
362 */
363 inline std::istream &operator>>(std::istream &is, Shape &shape) {
364   // get (
365   while (true) {
366     char ch = is.get();
367     if (ch == '(') break;
368     if (!isspace(ch)) {
369       is.setstate(std::ios::failbit);
370       return is;
371     }
372   }
373   index_t idx;
374   std::vector<index_t> tmp;
375   while (is >> idx) {
376     tmp.push_back(idx);
377     char ch;
378     do {
379       ch = is.get();
380     } while (isspace(ch));
381     if (ch == ',') {
382       while (true) {
383         ch = is.peek();
384         if (isspace(ch)) {
385           is.get(); continue;
386         }
387         if (ch == ')') {
388           is.get(); break;
389         }
390         break;
391       }
392       if (ch == ')') break;
393     } else if (ch == ')') {
394       break;
395     } else {
396       is.setstate(std::ios::failbit);
397       return is;
398     }
399   }
400   shape.CopyFrom(tmp.begin(), tmp.end());
401   return is;
402 }
403 
404 }  // namespace cpp
405 }  // namespace mxnet
406 
407 #endif  // MXNET_CPP_SHAPE_H_
408