1 /* 2 * Licensed to the Apache Software Foundation (ASF) under one 3 * or more contributor license agreements. See the NOTICE file 4 * distributed with this work for additional information 5 * regarding copyright ownership. The ASF licenses this file 6 * to you under the Apache License, Version 2.0 (the 7 * "License"); you may not use this file except in compliance 8 * with the License. You may obtain a copy of the License at 9 * 10 * http://www.apache.org/licenses/LICENSE-2.0 11 * 12 * Unless required by applicable law or agreed to in writing, 13 * software distributed under the License is distributed on an 14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 15 * KIND, either express or implied. See the License for the 16 * specific language governing permissions and limitations 17 * under the License. 18 */ 19 /*! 20 * \file tvm/node/container.h 21 * \brief Array/Map container in the DSL graph. 22 */ 23 #ifndef TVM_NODE_CONTAINER_H_ 24 #define TVM_NODE_CONTAINER_H_ 25 26 #include <type_traits> 27 #include <vector> 28 #include <initializer_list> 29 #include <unordered_map> 30 #include <utility> 31 #include <string> 32 #include "node.h" 33 #include "memory.h" 34 35 namespace tvm { 36 37 /*! \brief array node content in array */ 38 class ArrayNode : public Node { 39 public: 40 /*! \brief the data content */ 41 std::vector<ObjectRef> data; 42 VisitAttrs(AttrVisitor * visitor)43 void VisitAttrs(AttrVisitor* visitor) { 44 } 45 46 static constexpr const char* _type_key = "Array"; 47 TVM_DECLARE_FINAL_OBJECT_INFO(ArrayNode, Node); 48 }; 49 50 /*! \brief map node content */ 51 class MapNode : public Node { 52 public: VisitAttrs(AttrVisitor * visitor)53 void VisitAttrs(AttrVisitor* visitor) { 54 } 55 56 /*! \brief The corresponding conatiner type */ 57 using ContainerType = std::unordered_map< 58 ObjectRef, 59 ObjectRef, 60 ObjectHash, ObjectEqual>; 61 62 /*! \brief the data content */ 63 ContainerType data; 64 65 static constexpr const char* _type_key = "Map"; 66 TVM_DECLARE_FINAL_OBJECT_INFO(MapNode, Node); 67 }; 68 69 70 /*! \brief specialized map node with string as key */ 71 class StrMapNode : public Node { 72 public: 73 /*! \brief The corresponding conatiner type */ 74 using ContainerType = std::unordered_map<std::string, ObjectRef>; 75 VisitAttrs(AttrVisitor * visitor)76 void VisitAttrs(AttrVisitor* visitor) { 77 } 78 79 /*! \brief the data content */ 80 ContainerType data; 81 82 static constexpr const char* _type_key = "StrMap"; 83 TVM_DECLARE_FINAL_OBJECT_INFO(StrMapNode, Node); 84 }; 85 86 /*! 87 * \brief iterator adapter that adapts TIter to return another type. 88 * \tparam Converter a struct that contains converting function 89 * \tparam TIter the content iterator type. 90 */ 91 template<typename Converter, 92 typename TIter> 93 class IterAdapter { 94 public: 95 using difference_type = typename std::iterator_traits<TIter>::difference_type; 96 using value_type = typename Converter::ResultType; 97 using pointer = typename Converter::ResultType*; 98 using reference = typename Converter::ResultType&; // NOLINT(*) 99 using iterator_category = typename std::iterator_traits<TIter>::iterator_category; 100 IterAdapter(TIter iter)101 explicit IterAdapter(TIter iter) : iter_(iter) {} 102 inline IterAdapter& operator++() { 103 ++iter_; 104 return *this; 105 } 106 inline IterAdapter operator+(difference_type offset) const { 107 return IterAdapter(iter_ + offset); 108 } 109 110 template<typename T = IterAdapter> 111 typename std::enable_if<std::is_same<iterator_category, std::random_access_iterator_tag>::value, 112 typename T::difference_type>::type 113 inline operator-(const IterAdapter& rhs) const { 114 return iter_ - rhs.iter_; 115 } 116 117 inline bool operator==(IterAdapter other) const { 118 return iter_ == other.iter_; 119 } 120 inline bool operator!=(IterAdapter other) const { 121 return !(*this == other); 122 } 123 inline const value_type operator*() const { 124 return Converter::convert(*iter_); 125 } 126 127 private: 128 TIter iter_; 129 }; 130 131 /*! 132 * \brief Array container of NodeRef in DSL graph. 133 * Array implements copy on write semantics, which means array is mutable 134 * but copy will happen when array is referenced in more than two places. 135 * 136 * operator[] only provide const acces, use Set to mutate the content. 137 * \tparam T The content NodeRef type. 138 */ 139 template<typename T, 140 typename = typename std::enable_if<std::is_base_of<NodeRef, T>::value>::type > 141 class Array : public NodeRef { 142 public: 143 /*! 144 * \brief default constructor 145 */ Array()146 Array() { 147 data_ = make_node<ArrayNode>(); 148 } 149 /*! 150 * \brief move constructor 151 * \param other source 152 */ Array(Array<T> && other)153 Array(Array<T> && other) { // NOLINT(*) 154 data_ = std::move(other.data_); 155 } 156 /*! 157 * \brief copy constructor 158 * \param other source 159 */ Array(const Array<T> & other)160 Array(const Array<T> &other) { // NOLINT(*) 161 data_ = std::move(other.data_); 162 } 163 /*! 164 * \brief constructor from pointer 165 * \param n the container pointer 166 */ Array(ObjectPtr<Object> n)167 explicit Array(ObjectPtr<Object> n) : NodeRef(n) {} 168 /*! 169 * \brief constructor from iterator 170 * \param begin begin of iterator 171 * \param end end of iterator 172 * \tparam IterType The type of iterator 173 */ 174 template<typename IterType> Array(IterType begin,IterType end)175 Array(IterType begin, IterType end) { 176 assign(begin, end); 177 } 178 /*! 179 * \brief constructor from initializer list 180 * \param init The initalizer list 181 */ Array(std::initializer_list<T> init)182 Array(std::initializer_list<T> init) { // NOLINT(*) 183 assign(init.begin(), init.end()); 184 } 185 /*! 186 * \brief constructor from vector 187 * \param init The vector 188 */ Array(const std::vector<T> & init)189 Array(const std::vector<T>& init) { // NOLINT(*) 190 assign(init.begin(), init.end()); 191 } 192 /*! 193 * \brief Constructs a container with n elements. Each element is a copy of val 194 * \param n The size of the container 195 * \param val The init value 196 */ Array(size_t n,const T & val)197 explicit Array(size_t n, const T& val) { 198 auto tmp_node = make_node<ArrayNode>(); 199 for (size_t i = 0; i < n; ++i) { 200 tmp_node->data.push_back(val); 201 } 202 data_ = std::move(tmp_node); 203 } 204 /*! 205 * \brief move assign operator 206 * \param other The source of assignment 207 * \return reference to self. 208 */ 209 Array<T>& operator=(Array<T> && other) { 210 data_ = std::move(other.data_); 211 return *this; 212 } 213 /*! 214 * \brief copy assign operator 215 * \param other The source of assignment 216 * \return reference to self. 217 */ 218 Array<T>& operator=(const Array<T> & other) { 219 data_ = other.data_; 220 return *this; 221 } 222 /*! 223 * \brief reset the array to content from iterator. 224 * \param begin begin of iterator 225 * \param end end of iterator 226 * \tparam IterType The type of iterator 227 */ 228 template<typename IterType> assign(IterType begin,IterType end)229 void assign(IterType begin, IterType end) { 230 auto n = make_node<ArrayNode>(); 231 for (IterType it = begin; it != end; ++it) { 232 n->data.push_back(T(*it)); 233 } 234 data_ = std::move(n); 235 } 236 /*! 237 * \brief Read i-th element from array. 238 * \param i The index 239 * \return the i-th element. 240 */ 241 inline const T operator[](size_t i) const { 242 return DowncastNoCheck<T>( 243 static_cast<const ArrayNode*>(data_.get())->data[i]); 244 } 245 /*! \return The size of the array */ size()246 inline size_t size() const { 247 if (data_.get() == nullptr) return 0; 248 return static_cast<const ArrayNode*>(data_.get())->data.size(); 249 } 250 /*! 251 * \brief copy on write semantics 252 * Do nothing if current handle is the unique copy of the array. 253 * Otherwise make a new copy of the array to ensure the current handle 254 * hold a unique copy. 255 * 256 * \return Handle to the internal node container(which ganrantees to be unique) 257 */ CopyOnWrite()258 inline ArrayNode* CopyOnWrite() { 259 if (data_.get() == nullptr || !data_.unique()) { 260 NodePtr<ArrayNode> n = make_node<ArrayNode>(); 261 n->data = static_cast<ArrayNode*>(data_.get())->data; 262 ObjectPtr<Object>(std::move(n)).swap(data_); 263 } 264 return static_cast<ArrayNode*>(data_.get()); 265 } 266 /*! 267 * \brief push a new item to the back of the list 268 * \param item The item to be pushed. 269 */ push_back(const T & item)270 inline void push_back(const T& item) { 271 ArrayNode* n = this->CopyOnWrite(); 272 n->data.push_back(item); 273 } 274 /*! 275 * \brief set i-th element of the array. 276 * \param i The index 277 * \param value The value to be setted. 278 */ Set(size_t i,const T & value)279 inline void Set(size_t i, const T& value) { 280 ArrayNode* n = this->CopyOnWrite(); 281 n->data[i] = value; 282 } 283 /*! \return whether array is empty */ empty()284 inline bool empty() const { 285 return size() == 0; 286 } 287 /*! \brief specify container node */ 288 using ContainerType = ArrayNode; 289 290 struct ValueConverter { 291 using ResultType = T; convertValueConverter292 static inline T convert(const ObjectRef& n) { 293 return DowncastNoCheck<T>(n); 294 } 295 }; 296 using iterator = IterAdapter<ValueConverter, 297 std::vector<ObjectRef>::const_iterator>; 298 299 using reverse_iterator = IterAdapter< 300 ValueConverter, 301 std::vector<ObjectRef>::const_reverse_iterator>; 302 303 /*! \return begin iterator */ begin()304 inline iterator begin() const { 305 return iterator(static_cast<const ArrayNode*>(data_.get())->data.begin()); 306 } 307 /*! \return end iterator */ end()308 inline iterator end() const { 309 return iterator(static_cast<const ArrayNode*>(data_.get())->data.end()); 310 } 311 /*! \return rbegin iterator */ rbegin()312 inline reverse_iterator rbegin() const { 313 return reverse_iterator(static_cast<const ArrayNode*>(data_.get())->data.rbegin()); 314 } 315 /*! \return rend iterator */ rend()316 inline reverse_iterator rend() const { 317 return reverse_iterator(static_cast<const ArrayNode*>(data_.get())->data.rend()); 318 } 319 }; 320 321 /*! 322 * \brief Map container of NodeRef->NodeRef in DSL graph. 323 * Map implements copy on write semantics, which means map is mutable 324 * but copy will happen when array is referenced in more than two places. 325 * 326 * operator[] only provide const acces, use Set to mutate the content. 327 * \tparam K The key NodeRef type. 328 * \tparam V The value NodeRef type. 329 */ 330 template<typename K, 331 typename V, 332 typename = typename std::enable_if< 333 std::is_base_of<NodeRef, K>::value || 334 std::is_base_of<std::string, K>::value >::type, 335 typename = typename std::enable_if<std::is_base_of<NodeRef, V>::value>::type> 336 class Map : public NodeRef { 337 public: 338 /*! 339 * \brief default constructor 340 */ Map()341 Map() { 342 data_ = make_node<MapNode>(); 343 } 344 /*! 345 * \brief move constructor 346 * \param other source 347 */ Map(Map<K,V> && other)348 Map(Map<K, V> && other) { // NOLINT(*) 349 data_ = std::move(other.data_); 350 } 351 /*! 352 * \brief copy constructor 353 * \param other source 354 */ Map(const Map<K,V> & other)355 Map(const Map<K, V> &other) : NodeRef(other.data_) { // NOLINT(*) 356 } 357 /*! 358 * \brief constructor from pointer 359 * \param n the container pointer 360 */ Map(ObjectPtr<Object> n)361 explicit Map(ObjectPtr<Object> n) : NodeRef(n) {} 362 /*! 363 * \brief constructor from iterator 364 * \param begin begin of iterator 365 * \param end end of iterator 366 * \tparam IterType The type of iterator 367 */ 368 template<typename IterType> Map(IterType begin,IterType end)369 Map(IterType begin, IterType end) { 370 assign(begin, end); 371 } 372 /*! 373 * \brief constructor from initializer list 374 * \param init The initalizer list 375 */ Map(std::initializer_list<std::pair<K,V>> init)376 Map(std::initializer_list<std::pair<K, V> > init) { // NOLINT(*) 377 assign(init.begin(), init.end()); 378 } 379 /*! 380 * \brief constructor from vector 381 * \param init The vector 382 */ 383 template<typename Hash, typename Equal> Map(const std::unordered_map<K,V,Hash,Equal> & init)384 Map(const std::unordered_map<K, V, Hash, Equal>& init) { // NOLINT(*) 385 assign(init.begin(), init.end()); 386 } 387 /*! 388 * \brief move assign operator 389 * \param other The source of assignment 390 * \return reference to self. 391 */ 392 Map<K, V>& operator=(Map<K, V> && other) { 393 data_ = std::move(other.data_); 394 return *this; 395 } 396 /*! 397 * \brief copy assign operator 398 * \param other The source of assignment 399 * \return reference to self. 400 */ 401 Map<K, V>& operator=(const Map<K, V> & other) { 402 data_ = other.data_; 403 return *this; 404 } 405 /*! 406 * \brief reset the array to content from iterator. 407 * \param begin begin of iterator 408 * \param end end of iterator 409 * \tparam IterType The type of iterator 410 */ 411 template<typename IterType> assign(IterType begin,IterType end)412 void assign(IterType begin, IterType end) { 413 NodePtr<MapNode> n = make_node<MapNode>(); 414 for (IterType i = begin; i != end; ++i) { 415 n->data.emplace(std::make_pair(i->first, i->second)); 416 } 417 data_ = std::move(n); 418 } 419 /*! 420 * \brief Read element from map. 421 * \param key The key 422 * \return the corresonding element. 423 */ 424 inline const V operator[](const K& key) const { 425 return DowncastNoCheck<V>( 426 static_cast<const MapNode*>(data_.get())->data.at(key)); 427 } 428 /*! 429 * \brief Read element from map. 430 * \param key The key 431 * \return the corresonding element. 432 */ at(const K & key)433 inline const V at(const K& key) const { 434 return DowncastNoCheck<V>( 435 static_cast<const MapNode*>(data_.get())->data.at(key)); 436 } 437 /*! \return The size of the array */ size()438 inline size_t size() const { 439 if (data_.get() == nullptr) return 0; 440 return static_cast<const MapNode*>(data_.get())->data.size(); 441 } 442 /*! \return The number of elements of the key */ count(const K & key)443 inline size_t count(const K& key) const { 444 if (data_.get() == nullptr) return 0; 445 return static_cast<const MapNode*>(data_.get())->data.count(key); 446 } 447 /*! 448 * \brief copy on write semantics 449 * Do nothing if current handle is the unique copy of the array. 450 * Otherwise make a new copy of the array to ensure the current handle 451 * hold a unique copy. 452 * 453 * \return Handle to the internal node container(which ganrantees to be unique) 454 */ CopyOnWrite()455 inline MapNode* CopyOnWrite() { 456 if (data_.get() == nullptr || !data_.unique()) { 457 NodePtr<MapNode> n = make_node<MapNode>(); 458 n->data = static_cast<const MapNode*>(data_.get())->data; 459 ObjectPtr<Object>(std::move(n)).swap(data_); 460 } 461 return static_cast<MapNode*>(data_.get()); 462 } 463 /*! 464 * \brief set the Map. 465 * \param key The index key. 466 * \param value The value to be setted. 467 */ Set(const K & key,const V & value)468 inline void Set(const K& key, const V& value) { 469 MapNode* n = this->CopyOnWrite(); 470 n->data[key] = value; 471 } 472 473 /*! \return whether array is empty */ empty()474 inline bool empty() const { 475 return size() == 0; 476 } 477 /*! \brief specify container node */ 478 using ContainerType = MapNode; 479 480 struct ValueConverter { 481 using ResultType = std::pair<K, V>; convertValueConverter482 static inline ResultType convert(const std::pair< 483 ObjectRef, 484 ObjectRef>& n) { 485 return std::make_pair(DowncastNoCheck<K>(n.first), 486 DowncastNoCheck<V>(n.second)); 487 } 488 }; 489 490 using iterator = IterAdapter< 491 ValueConverter, MapNode::ContainerType::const_iterator>; 492 493 /*! \return begin iterator */ begin()494 inline iterator begin() const { 495 return iterator(static_cast<const MapNode*>(data_.get())->data.begin()); 496 } 497 /*! \return end iterator */ end()498 inline iterator end() const { 499 return iterator(static_cast<const MapNode*>(data_.get())->data.end()); 500 } 501 /*! \return begin iterator */ find(const K & key)502 inline iterator find(const K& key) const { 503 return iterator( 504 static_cast<const MapNode*>(data_.get())->data.find(key)); 505 } 506 }; 507 508 // specialize of string map 509 template<typename V, typename T1, typename T2> 510 class Map<std::string, V, T1, T2> : public NodeRef { 511 public: 512 // for code reuse Map()513 Map() { 514 data_ = make_node<StrMapNode>(); 515 } Map(Map<std::string,V> && other)516 Map(Map<std::string, V> && other) { // NOLINT(*) 517 data_ = std::move(other.data_); 518 } Map(const Map<std::string,V> & other)519 Map(const Map<std::string, V> &other) : NodeRef(other.data_) { // NOLINT(*) 520 } Map(ObjectPtr<Object> n)521 explicit Map(ObjectPtr<Object> n) : NodeRef(n) {} 522 template<typename IterType> Map(IterType begin,IterType end)523 Map(IterType begin, IterType end) { 524 assign(begin, end); 525 } Map(std::initializer_list<std::pair<std::string,V>> init)526 Map(std::initializer_list<std::pair<std::string, V> > init) { // NOLINT(*) 527 assign(init.begin(), init.end()); 528 } 529 530 template<typename Hash, typename Equal> Map(const std::unordered_map<std::string,V,Hash,Equal> & init)531 Map(const std::unordered_map<std::string, V, Hash, Equal>& init) { // NOLINT(*) 532 assign(init.begin(), init.end()); 533 } 534 Map<std::string, V>& operator=(Map<std::string, V> && other) { 535 data_ = std::move(other.data_); 536 return *this; 537 } 538 Map<std::string, V>& operator=(const Map<std::string, V> & other) { 539 data_ = other.data_; 540 return *this; 541 } 542 template<typename IterType> assign(IterType begin,IterType end)543 void assign(IterType begin, IterType end) { 544 auto n = make_node<StrMapNode>(); 545 for (IterType i = begin; i != end; ++i) { 546 n->data.emplace(std::make_pair(i->first, i->second)); 547 } 548 data_ = std::move(n); 549 } 550 inline const V operator[](const std::string& key) const { 551 return DowncastNoCheck<V>( 552 static_cast<const StrMapNode*>(data_.get())->data.at(key)); 553 } at(const std::string & key)554 inline const V at(const std::string& key) const { 555 return DowncastNoCheck<V>( 556 static_cast<const StrMapNode*>(data_.get())->data.at(key)); 557 } size()558 inline size_t size() const { 559 if (data_.get() == nullptr) return 0; 560 return static_cast<const StrMapNode*>(data_.get())->data.size(); 561 } count(const std::string & key)562 inline size_t count(const std::string& key) const { 563 if (data_.get() == nullptr) return 0; 564 return static_cast<const StrMapNode*>(data_.get())->data.count(key); 565 } CopyOnWrite()566 inline StrMapNode* CopyOnWrite() { 567 if (data_.get() == nullptr || !data_.unique()) { 568 NodePtr<StrMapNode> n = make_node<StrMapNode>(); 569 n->data = static_cast<const StrMapNode*>(data_.get())->data; 570 ObjectPtr<Object>(std::move(n)).swap(data_); 571 } 572 return static_cast<StrMapNode*>(data_.get()); 573 } Set(const std::string & key,const V & value)574 inline void Set(const std::string& key, const V& value) { 575 StrMapNode* n = this->CopyOnWrite(); 576 n->data[key] = value; 577 } empty()578 inline bool empty() const { 579 return size() == 0; 580 } 581 using ContainerType = StrMapNode; 582 583 struct ValueConverter { 584 using ResultType = std::pair<std::string, V>; convertValueConverter585 static inline ResultType convert(const std::pair< 586 std::string, 587 ObjectRef>& n) { 588 return std::make_pair(n.first, DowncastNoCheck<V>(n.second)); 589 } 590 }; 591 592 using iterator = IterAdapter< 593 ValueConverter, StrMapNode::ContainerType::const_iterator>; 594 595 /*! \return begin iterator */ begin()596 inline iterator begin() const { 597 return iterator(static_cast<const StrMapNode*>(data_.get())->data.begin()); 598 } 599 /*! \return end iterator */ end()600 inline iterator end() const { 601 return iterator(static_cast<const StrMapNode*>(data_.get())->data.end()); 602 } 603 /*! \return begin iterator */ find(const std::string & key)604 inline iterator find(const std::string& key) const { 605 return iterator(static_cast<const StrMapNode*>(data_.get())->data.find(key)); 606 } 607 }; 608 609 } // namespace tvm 610 #endif // TVM_NODE_CONTAINER_H_ 611