1 /* 2 * Licensed to the Apache Software Foundation (ASF) under one 3 * or more contributor license agreements. See the NOTICE file 4 * distributed with this work for additional information 5 * regarding copyright ownership. The ASF licenses this file 6 * to you under the Apache License, Version 2.0 (the 7 * "License"); you may not use this file except in compliance 8 * with the License. You may obtain a copy of the License at 9 * 10 * http://www.apache.org/licenses/LICENSE-2.0 11 * 12 * Unless required by applicable law or agreed to in writing, 13 * software distributed under the License is distributed on an 14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 15 * KIND, either express or implied. See the License for the 16 * specific language governing permissions and limitations 17 * under the License. 18 */ 19 20 /*! 21 * \file shape.h 22 * \brief definition of shape 23 * \author Chuntao Hong, Zhang Chen 24 */ 25 26 #ifndef MXNET_CPP_SHAPE_H_ 27 #define MXNET_CPP_SHAPE_H_ 28 29 #include <istream> 30 #include <ostream> 31 #include <algorithm> 32 #include <vector> 33 #include "mxnet-cpp/base.h" 34 35 namespace mxnet { 36 namespace cpp { 37 38 /*! 39 * \brief dynamic shape class that can hold shape 40 * of arbirary dimension 41 */ 42 struct Shape { 43 public: 44 /*! \brief constructor */ ShapeShape45 Shape() 46 : ndim_(0), 47 num_heap_allocated_(0), 48 data_heap_(nullptr) {} 49 /*! 50 * \brief constructor from a vector of index_t 51 * \param v the vector 52 */ ShapeShape53 explicit Shape(const std::vector<index_t> &v) 54 : ndim_(v.size()) { 55 if (ndim_ <= kStackCache) { 56 data_heap_ = nullptr; 57 num_heap_allocated_ = 0; 58 std::copy(v.begin(), v.end(), data_stack_); 59 } else { 60 data_heap_ = new index_t[ndim_]; 61 num_heap_allocated_ = ndim_; 62 std::copy(v.begin(), v.end(), data_heap_); 63 } 64 } 65 /*! 66 * \brief constructor one dimmension shape 67 * \param s1 size of the first dimmension 68 */ ShapeShape69 explicit Shape(index_t s1) 70 : ndim_(1) { 71 if (ndim_ <= kStackCache) { 72 data_heap_ = nullptr; 73 num_heap_allocated_ = 0; 74 data_stack_[0] = s1; 75 } else { 76 data_heap_ = new index_t[ndim_]; 77 num_heap_allocated_ = ndim_; 78 data_heap_[0] = s1; 79 } 80 } 81 /*! 82 * \brief constructor two dimmension shape 83 * \param s1 size of the first dimmension 84 * \param s2 size of the second dimmension 85 */ ShapeShape86 Shape(index_t s1, index_t s2) 87 : ndim_(2) { 88 if (ndim_ <= kStackCache) { 89 data_heap_ = nullptr; 90 num_heap_allocated_ = 0; 91 data_stack_[0] = s1; 92 data_stack_[1] = s2; 93 } else { 94 data_heap_ = new index_t[ndim_]; 95 num_heap_allocated_ = ndim_; 96 data_heap_[0] = s1; 97 data_heap_[1] = s2; 98 } 99 } 100 /*! 101 * \brief constructor three dimmension shape 102 * \param s1 size of the first dimmension 103 * \param s2 size of the second dimmension 104 * \param s3 size of the third dimmension 105 */ ShapeShape106 Shape(index_t s1, index_t s2, index_t s3) 107 : ndim_(3) { 108 if (ndim_ <= kStackCache) { 109 data_heap_ = nullptr; 110 num_heap_allocated_ = 0; 111 data_stack_[0] = s1; 112 data_stack_[1] = s2; 113 data_stack_[2] = s3; 114 } else { 115 data_heap_ = new index_t[ndim_]; 116 num_heap_allocated_ = ndim_; 117 data_heap_[0] = s1; 118 data_heap_[1] = s2; 119 data_heap_[2] = s3; 120 } 121 } 122 /*! 123 * \brief constructor four dimmension shape 124 * \param s1 size of the first dimmension 125 * \param s2 size of the second dimmension 126 * \param s3 size of the third dimmension 127 * \param s4 size of the fourth dimmension 128 */ ShapeShape129 Shape(index_t s1, index_t s2, index_t s3, index_t s4) 130 : ndim_(4) { 131 if (ndim_ <= kStackCache) { 132 data_heap_ = nullptr; 133 num_heap_allocated_ = 0; 134 data_stack_[0] = s1; 135 data_stack_[1] = s2; 136 data_stack_[2] = s3; 137 data_stack_[3] = s4; 138 } else { 139 data_heap_ = new index_t[ndim_]; 140 num_heap_allocated_ = ndim_; 141 data_heap_[0] = s1; 142 data_heap_[1] = s2; 143 data_heap_[2] = s3; 144 data_heap_[3] = s4; 145 } 146 } 147 /*! 148 * \brief constructor five dimmension shape 149 * \param s1 size of the first dimmension 150 * \param s2 size of the second dimmension 151 * \param s3 size of the third dimmension 152 * \param s4 size of the fourth dimmension 153 * \param s5 size of the fifth dimmension 154 */ ShapeShape155 Shape(index_t s1, index_t s2, index_t s3, index_t s4, index_t s5) 156 : ndim_(5) { 157 if (ndim_ <= kStackCache) { 158 data_heap_ = nullptr; 159 num_heap_allocated_ = 0; 160 data_stack_[0] = s1; 161 data_stack_[1] = s2; 162 data_stack_[2] = s3; 163 data_stack_[3] = s4; 164 data_stack_[4] = s5; 165 } else { 166 data_heap_ = new index_t[ndim_]; 167 num_heap_allocated_ = ndim_; 168 data_heap_[0] = s1; 169 data_heap_[1] = s2; 170 data_heap_[2] = s3; 171 data_heap_[3] = s4; 172 data_heap_[4] = s5; 173 } 174 } 175 /*! 176 * \brief constructor from Shape 177 * \param s the source shape 178 */ ShapeShape179 Shape(const Shape &s) 180 : ndim_(s.ndim_) { 181 if (ndim_ <= kStackCache) { 182 data_heap_ = nullptr; 183 num_heap_allocated_ = 0; 184 std::copy(s.data_stack_, s.data_stack_ + ndim_, data_stack_); 185 } else { 186 data_heap_ = new index_t[ndim_]; 187 num_heap_allocated_ = ndim_; 188 std::copy(s.data_heap_, s.data_heap_ + ndim_, data_heap_); 189 } 190 } 191 #if MSHADOW_IN_CXX11 192 /*! 193 * \brief move constructor from Shape 194 * \param s the source shape 195 */ ShapeShape196 Shape(Shape &&s) 197 : ndim_(s.ndim_), 198 num_heap_allocated_(s.num_heap_allocated_), 199 data_heap_(s.data_heap_) { 200 if (ndim_ <= kStackCache) { 201 std::copy(s.data_stack_, s.data_stack_ + ndim_, data_stack_); 202 } 203 // remove data heap space from s 204 s.data_heap_ = nullptr; 205 } 206 #endif 207 /*! \brief destructor */ ~ShapeShape208 ~Shape() { 209 // data_heap_ can be nullptr 210 delete[] data_heap_; 211 } 212 /*! 213 * \brief copy shape from content betwen two iterators 214 * \param begin the beginning of iterator 215 * \param end the end of the iterator 216 * \tparam RandomAccessIterator iterator type 217 */ 218 template<typename RandomAccessIterator> CopyFromShape219 inline void CopyFrom(RandomAccessIterator begin, 220 RandomAccessIterator end) { 221 this->SetDim(end - begin); 222 std::copy(begin, end, data()); 223 } 224 /*! 225 * \brief assignment from shape 226 * \param shape source shape 227 * \return reference of self 228 */ 229 inline Shape &operator=(const Shape &shape) { 230 this->SetDim(shape.ndim_); 231 const index_t *src = shape.data(); 232 std::copy(src, src + ndim_, data()); 233 return *this; 234 } 235 /*! 236 * \brief assignment from vector 237 * \param shape source shape 238 * \return reference of self 239 */ 240 inline Shape &operator=(const std::vector<index_t> &shape) { 241 this->CopyFrom(shape.begin(), shape.end()); 242 return *this; 243 } 244 /*! \return the data content of the shape */ dataShape245 inline const index_t *data() const { 246 return ndim_ <= kStackCache ? data_stack_ : data_heap_; 247 } 248 /*! \return the data content of the shape */ dataShape249 inline index_t *data() { 250 return ndim_ <= kStackCache ? data_stack_ : data_heap_; 251 } 252 /*! \brief return number of dimension of the tensor inside */ ndimShape253 inline index_t ndim(void) const { 254 return ndim_; 255 } 256 /*! 257 * \brief get corresponding index 258 * \param i dimension index 259 * \return the corresponding dimension size 260 */ 261 inline index_t &operator[](index_t i) { 262 return data()[i]; 263 } 264 /*! 265 * \brief get corresponding index 266 * \param i dimension index 267 * \return the corresponding dimension size 268 */ 269 inline const index_t &operator[](index_t i) const { 270 return data()[i]; 271 } 272 /*! \brief total number of elements in the tensor */ SizeShape273 inline size_t Size(void) const { 274 size_t size = 1; 275 const index_t *d = this->data(); 276 for (index_t i = 0; i < ndim_; ++i) { 277 size *= d[i]; 278 } 279 return size; 280 } 281 /*! 282 * \return whether two shape equals 283 * \param s the shape to compare against 284 */ 285 inline bool operator==(const Shape &s) const { 286 if (ndim_ != s.ndim_) return false; 287 if (ndim_ <= kStackCache) { 288 for (index_t i = 0; i < ndim_; ++i) { 289 if (data_stack_[i] != s.data_stack_[i]) return false; 290 } 291 } else { 292 for (index_t i = 0; i < ndim_; ++i) { 293 if (data_heap_[i] != s.data_heap_[i]) return false; 294 } 295 } 296 return true; 297 } 298 /*! 299 * \return whether two shape not equals 300 * \param s the shape to compare against 301 */ 302 inline bool operator!=(const Shape &s) const { 303 return !(*this == s); 304 } 305 306 friend std::ostream &operator<<(std::ostream &os, const Shape &shape); 307 friend std::istream &operator>>(std::istream &is, Shape &shape); 308 309 private: 310 // the shape will be stored in data_stack_ 311 // when dimension is smaller than kStackCache 312 // when it is bigger, it will be stored in data_heap_; 313 /*! \brief size of in stack space */ 314 static const index_t kStackCache = 5; 315 /*! \brief number of dimnsion of the shape */ 316 index_t ndim_; 317 /*! \brief number of cells allocated in data_heap_ */ 318 index_t num_heap_allocated_; 319 /*! \brief in stack space used to store shape when it is small */ 320 index_t data_stack_[kStackCache]; 321 /*! \brief space to store shape when dimension is big*/ 322 index_t *data_heap_; 323 /*! 324 * \brief internal function to set the dimension 325 * \param dim the dimension of the shape 326 */ SetDimShape327 inline void SetDim(index_t dim) { 328 if (dim > kStackCache && 329 dim > num_heap_allocated_) { 330 // data_heap_ can be nullptr 331 delete[] data_heap_; 332 data_heap_ = new index_t[dim]; 333 num_heap_allocated_ = dim; 334 } 335 ndim_ = dim; 336 } 337 }; 338 339 /*! 340 * \brief allow string printing of the shape 341 * \param os the output stream 342 * \param shape the shape 343 * \return the ostream 344 */ 345 inline std::ostream &operator<<(std::ostream &os, const Shape &shape) { 346 os << '('; 347 for (index_t i = 0; i < shape.ndim(); ++i) { 348 if (i != 0) os << ','; 349 os << static_cast<int>(shape[i]); // Supports negative Shape 'special codes' for inferring 350 } 351 // python style tuple 352 if (shape.ndim() == 1) os << ','; 353 os << ')'; 354 return os; 355 } 356 357 /*! 358 * \brief read shape from the istream 359 * \param is the input stream 360 * \param shape the shape 361 * \return the istream 362 */ 363 inline std::istream &operator>>(std::istream &is, Shape &shape) { 364 // get ( 365 while (true) { 366 char ch = is.get(); 367 if (ch == '(') break; 368 if (!isspace(ch)) { 369 is.setstate(std::ios::failbit); 370 return is; 371 } 372 } 373 index_t idx; 374 std::vector<index_t> tmp; 375 while (is >> idx) { 376 tmp.push_back(idx); 377 char ch; 378 do { 379 ch = is.get(); 380 } while (isspace(ch)); 381 if (ch == ',') { 382 while (true) { 383 ch = is.peek(); 384 if (isspace(ch)) { 385 is.get(); continue; 386 } 387 if (ch == ')') { 388 is.get(); break; 389 } 390 break; 391 } 392 if (ch == ')') break; 393 } else if (ch == ')') { 394 break; 395 } else { 396 is.setstate(std::ios::failbit); 397 return is; 398 } 399 } 400 shape.CopyFrom(tmp.begin(), tmp.end()); 401 return is; 402 } 403 404 } // namespace cpp 405 } // namespace mxnet 406 407 #endif // MXNET_CPP_SHAPE_H_ 408