1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements.  See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership.  The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License.  You may obtain a copy of the License at
9  *
10  *   http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied.  See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 /*!
20  * \file tvm/node/container.h
21  * \brief Array/Map container in the DSL graph.
22  */
23 #ifndef TVM_NODE_CONTAINER_H_
24 #define TVM_NODE_CONTAINER_H_
25 
#include <initializer_list>
#include <iterator>
#include <string>
#include <type_traits>
#include <unordered_map>
#include <utility>
#include <vector>

#include "node.h"
#include "memory.h"
34 
35 namespace tvm {
36 
37 /*! \brief array node content in array */
38 class ArrayNode : public Node {
39  public:
40   /*! \brief the data content */
41   std::vector<ObjectRef> data;
42 
VisitAttrs(AttrVisitor * visitor)43   void VisitAttrs(AttrVisitor* visitor) {
44   }
45 
46   static constexpr const char* _type_key = "Array";
47   TVM_DECLARE_FINAL_OBJECT_INFO(ArrayNode, Node);
48 };
49 
50 /*! \brief map node content */
51 class MapNode : public Node {
52  public:
VisitAttrs(AttrVisitor * visitor)53   void VisitAttrs(AttrVisitor* visitor) {
54   }
55 
56   /*! \brief The corresponding conatiner type */
57   using ContainerType = std::unordered_map<
58     ObjectRef,
59     ObjectRef,
60     ObjectHash, ObjectEqual>;
61 
62   /*! \brief the data content */
63   ContainerType data;
64 
65   static constexpr const char* _type_key = "Map";
66   TVM_DECLARE_FINAL_OBJECT_INFO(MapNode, Node);
67 };
68 
69 
70 /*! \brief specialized map node with string as key */
71 class StrMapNode : public Node {
72  public:
73   /*! \brief The corresponding conatiner type */
74   using ContainerType = std::unordered_map<std::string, ObjectRef>;
75 
VisitAttrs(AttrVisitor * visitor)76   void VisitAttrs(AttrVisitor* visitor) {
77   }
78 
79   /*! \brief the data content */
80   ContainerType data;
81 
82   static constexpr const char* _type_key = "StrMap";
83   TVM_DECLARE_FINAL_OBJECT_INFO(StrMapNode, Node);
84 };
85 
/*!
 * \brief Iterator wrapper that walks an underlying TIter and yields converted
 *  values instead of the raw elements.
 * \tparam Converter a struct exposing a ResultType alias and a static
 *  convert() function that is applied to each dereferenced element.
 * \tparam TIter the wrapped iterator type.
 */
template<typename Converter, typename TIter>
class IterAdapter {
 public:
  using difference_type = typename std::iterator_traits<TIter>::difference_type;
  using value_type = typename Converter::ResultType;
  using pointer = value_type*;
  using reference = value_type&;   // NOLINT(*)
  using iterator_category = typename std::iterator_traits<TIter>::iterator_category;

  /*! \brief Wrap the given underlying iterator. */
  explicit IterAdapter(TIter iter) : iter_(iter) {}

  /*! \brief Pre-increment: advance the wrapped iterator by one step. */
  IterAdapter& operator++() {
    ++iter_;
    return *this;
  }

  /*! \brief Adapter positioned `offset` steps away from this one. */
  IterAdapter operator+(difference_type offset) const {
    TIter shifted = iter_ + offset;
    return IterAdapter(shifted);
  }

  /*!
   * \brief Distance between two adapters.
   *  Only participates in overload resolution when the wrapped iterator's
   *  category is exactly std::random_access_iterator_tag.
   */
  template<typename T = IterAdapter>
  typename std::enable_if<std::is_same<iterator_category, std::random_access_iterator_tag>::value,
                          typename T::difference_type>::type
  operator-(const IterAdapter& rhs) const {
    return iter_ - rhs.iter_;
  }

  /*! \brief Two adapters are equal when their wrapped iterators are equal. */
  bool operator==(IterAdapter other) const {
    return iter_ == other.iter_;
  }
  /*! \brief Negation of operator==. */
  bool operator!=(IterAdapter other) const {
    const bool same = (*this == other);
    return !same;
  }

  /*! \brief Dereference: the current element run through Converter::convert, returned by value. */
  const value_type operator*() const {
    return Converter::convert(*iter_);
  }

 private:
  /*! \brief The wrapped iterator. */
  TIter iter_;
};
130 
131 /*!
132  * \brief Array container of NodeRef in DSL graph.
133  *  Array implements copy on write semantics, which means array is mutable
134  *  but copy will happen when array is referenced in more than two places.
135  *
136  * operator[] only provide const acces, use Set to mutate the content.
137  * \tparam T The content NodeRef type.
138  */
139 template<typename T,
140          typename = typename std::enable_if<std::is_base_of<NodeRef, T>::value>::type >
141 class Array : public NodeRef {
142  public:
143   /*!
144    * \brief default constructor
145    */
Array()146   Array() {
147     data_ = make_node<ArrayNode>();
148   }
149   /*!
150    * \brief move constructor
151    * \param other source
152    */
Array(Array<T> && other)153   Array(Array<T> && other) {  // NOLINT(*)
154     data_ = std::move(other.data_);
155   }
156   /*!
157    * \brief copy constructor
158    * \param other source
159    */
Array(const Array<T> & other)160   Array(const Array<T> &other) { // NOLINT(*)
161     data_ = std::move(other.data_);
162   }
163   /*!
164    * \brief constructor from pointer
165    * \param n the container pointer
166    */
Array(ObjectPtr<Object> n)167   explicit Array(ObjectPtr<Object> n) : NodeRef(n) {}
168   /*!
169    * \brief constructor from iterator
170    * \param begin begin of iterator
171    * \param end end of iterator
172    * \tparam IterType The type of iterator
173    */
174   template<typename IterType>
Array(IterType begin,IterType end)175   Array(IterType begin, IterType end) {
176     assign(begin, end);
177   }
178   /*!
179    * \brief constructor from initializer list
180    * \param init The initalizer list
181    */
Array(std::initializer_list<T> init)182   Array(std::initializer_list<T> init) { // NOLINT(*)
183     assign(init.begin(), init.end());
184   }
185   /*!
186    * \brief constructor from vector
187    * \param init The vector
188    */
Array(const std::vector<T> & init)189   Array(const std::vector<T>& init) { // NOLINT(*)
190     assign(init.begin(), init.end());
191   }
192   /*!
193    * \brief Constructs a container with n elements. Each element is a copy of val
194    * \param n The size of the container
195    * \param val The init value
196    */
Array(size_t n,const T & val)197   explicit Array(size_t n, const T& val) {
198     auto tmp_node = make_node<ArrayNode>();
199     for (size_t i = 0; i < n; ++i) {
200       tmp_node->data.push_back(val);
201     }
202     data_ = std::move(tmp_node);
203   }
204   /*!
205    * \brief move assign operator
206    * \param other The source of assignment
207    * \return reference to self.
208    */
209   Array<T>& operator=(Array<T> && other) {
210     data_ = std::move(other.data_);
211     return *this;
212   }
213   /*!
214    * \brief copy assign operator
215    * \param other The source of assignment
216    * \return reference to self.
217    */
218   Array<T>& operator=(const Array<T> & other) {
219     data_ = other.data_;
220     return *this;
221   }
222   /*!
223    * \brief reset the array to content from iterator.
224    * \param begin begin of iterator
225    * \param end end of iterator
226    * \tparam IterType The type of iterator
227    */
228   template<typename IterType>
assign(IterType begin,IterType end)229   void assign(IterType begin, IterType end) {
230     auto n = make_node<ArrayNode>();
231     for (IterType it = begin; it != end; ++it) {
232       n->data.push_back(T(*it));
233     }
234     data_ = std::move(n);
235   }
236   /*!
237    * \brief Read i-th element from array.
238    * \param i The index
239    * \return the i-th element.
240    */
241   inline const T operator[](size_t i) const {
242     return DowncastNoCheck<T>(
243         static_cast<const ArrayNode*>(data_.get())->data[i]);
244   }
245   /*! \return The size of the array */
size()246   inline size_t size() const {
247     if (data_.get() == nullptr) return 0;
248     return static_cast<const ArrayNode*>(data_.get())->data.size();
249   }
250   /*!
251    * \brief copy on write semantics
252    *  Do nothing if current handle is the unique copy of the array.
253    *  Otherwise make a new copy of the array to ensure the current handle
254    *  hold a unique copy.
255    *
256    * \return Handle to the internal node container(which ganrantees to be unique)
257    */
CopyOnWrite()258   inline ArrayNode* CopyOnWrite() {
259     if (data_.get() == nullptr || !data_.unique())  {
260       NodePtr<ArrayNode> n = make_node<ArrayNode>();
261       n->data = static_cast<ArrayNode*>(data_.get())->data;
262       ObjectPtr<Object>(std::move(n)).swap(data_);
263     }
264     return static_cast<ArrayNode*>(data_.get());
265   }
266   /*!
267    * \brief push a new item to the back of the list
268    * \param item The item to be pushed.
269    */
push_back(const T & item)270   inline void push_back(const T& item) {
271     ArrayNode* n = this->CopyOnWrite();
272     n->data.push_back(item);
273   }
274   /*!
275    * \brief set i-th element of the array.
276    * \param i The index
277    * \param value The value to be setted.
278    */
Set(size_t i,const T & value)279   inline void Set(size_t i, const T& value) {
280     ArrayNode* n = this->CopyOnWrite();
281     n->data[i] = value;
282   }
283   /*! \return whether array is empty */
empty()284   inline bool empty() const {
285     return size() == 0;
286   }
287   /*! \brief specify container node */
288   using ContainerType = ArrayNode;
289 
290   struct ValueConverter {
291     using ResultType = T;
convertValueConverter292     static inline T convert(const ObjectRef& n) {
293       return DowncastNoCheck<T>(n);
294     }
295   };
296   using iterator = IterAdapter<ValueConverter,
297                                std::vector<ObjectRef>::const_iterator>;
298 
299   using reverse_iterator = IterAdapter<
300     ValueConverter,
301     std::vector<ObjectRef>::const_reverse_iterator>;
302 
303   /*! \return begin iterator */
begin()304   inline iterator begin() const {
305     return iterator(static_cast<const ArrayNode*>(data_.get())->data.begin());
306   }
307   /*! \return end iterator */
end()308   inline iterator end() const {
309     return iterator(static_cast<const ArrayNode*>(data_.get())->data.end());
310   }
311   /*! \return rbegin iterator */
rbegin()312   inline reverse_iterator rbegin() const {
313     return reverse_iterator(static_cast<const ArrayNode*>(data_.get())->data.rbegin());
314   }
315   /*! \return rend iterator */
rend()316   inline reverse_iterator rend() const {
317     return reverse_iterator(static_cast<const ArrayNode*>(data_.get())->data.rend());
318   }
319 };
320 
321 /*!
322  * \brief Map container of NodeRef->NodeRef in DSL graph.
323  *  Map implements copy on write semantics, which means map is mutable
324  *  but copy will happen when array is referenced in more than two places.
325  *
326  * operator[] only provide const acces, use Set to mutate the content.
327  * \tparam K The key NodeRef type.
328  * \tparam V The value NodeRef type.
329  */
330 template<typename K,
331          typename V,
332          typename = typename std::enable_if<
333            std::is_base_of<NodeRef, K>::value ||
334            std::is_base_of<std::string, K>::value >::type,
335          typename = typename std::enable_if<std::is_base_of<NodeRef, V>::value>::type>
336 class Map : public NodeRef {
337  public:
338   /*!
339    * \brief default constructor
340    */
Map()341   Map() {
342     data_ = make_node<MapNode>();
343   }
344   /*!
345    * \brief move constructor
346    * \param other source
347    */
Map(Map<K,V> && other)348   Map(Map<K, V> && other) {  // NOLINT(*)
349     data_ = std::move(other.data_);
350   }
351   /*!
352    * \brief copy constructor
353    * \param other source
354    */
Map(const Map<K,V> & other)355   Map(const Map<K, V> &other) : NodeRef(other.data_) { // NOLINT(*)
356   }
357   /*!
358    * \brief constructor from pointer
359    * \param n the container pointer
360    */
Map(ObjectPtr<Object> n)361   explicit Map(ObjectPtr<Object> n) : NodeRef(n) {}
362   /*!
363    * \brief constructor from iterator
364    * \param begin begin of iterator
365    * \param end end of iterator
366    * \tparam IterType The type of iterator
367    */
368   template<typename IterType>
Map(IterType begin,IterType end)369   Map(IterType begin, IterType end) {
370     assign(begin, end);
371   }
372   /*!
373    * \brief constructor from initializer list
374    * \param init The initalizer list
375    */
Map(std::initializer_list<std::pair<K,V>> init)376   Map(std::initializer_list<std::pair<K, V> > init) { // NOLINT(*)
377     assign(init.begin(), init.end());
378   }
379   /*!
380    * \brief constructor from vector
381    * \param init The vector
382    */
383   template<typename Hash, typename Equal>
Map(const std::unordered_map<K,V,Hash,Equal> & init)384   Map(const std::unordered_map<K, V, Hash, Equal>& init) { // NOLINT(*)
385     assign(init.begin(), init.end());
386   }
387   /*!
388    * \brief move assign operator
389    * \param other The source of assignment
390    * \return reference to self.
391    */
392   Map<K, V>& operator=(Map<K, V> && other) {
393     data_ = std::move(other.data_);
394     return *this;
395   }
396   /*!
397    * \brief copy assign operator
398    * \param other The source of assignment
399    * \return reference to self.
400    */
401   Map<K, V>& operator=(const Map<K, V> & other) {
402     data_ = other.data_;
403     return *this;
404   }
405   /*!
406    * \brief reset the array to content from iterator.
407    * \param begin begin of iterator
408    * \param end end of iterator
409    * \tparam IterType The type of iterator
410    */
411   template<typename IterType>
assign(IterType begin,IterType end)412   void assign(IterType begin, IterType end) {
413     NodePtr<MapNode> n = make_node<MapNode>();
414     for (IterType i = begin; i != end; ++i) {
415       n->data.emplace(std::make_pair(i->first, i->second));
416     }
417     data_ = std::move(n);
418   }
419   /*!
420    * \brief Read element from map.
421    * \param key The key
422    * \return the corresonding element.
423    */
424   inline const V operator[](const K& key) const {
425     return DowncastNoCheck<V>(
426         static_cast<const MapNode*>(data_.get())->data.at(key));
427   }
428   /*!
429    * \brief Read element from map.
430    * \param key The key
431    * \return the corresonding element.
432    */
at(const K & key)433   inline const V at(const K& key) const {
434     return DowncastNoCheck<V>(
435         static_cast<const MapNode*>(data_.get())->data.at(key));
436   }
437   /*! \return The size of the array */
size()438   inline size_t size() const {
439     if (data_.get() == nullptr) return 0;
440     return static_cast<const MapNode*>(data_.get())->data.size();
441   }
442   /*! \return The number of elements of the key */
count(const K & key)443   inline size_t count(const K& key) const {
444     if (data_.get() == nullptr) return 0;
445     return static_cast<const MapNode*>(data_.get())->data.count(key);
446   }
447   /*!
448    * \brief copy on write semantics
449    *  Do nothing if current handle is the unique copy of the array.
450    *  Otherwise make a new copy of the array to ensure the current handle
451    *  hold a unique copy.
452    *
453    * \return Handle to the internal node container(which ganrantees to be unique)
454    */
CopyOnWrite()455   inline MapNode* CopyOnWrite() {
456     if (data_.get() == nullptr || !data_.unique())  {
457       NodePtr<MapNode> n = make_node<MapNode>();
458       n->data = static_cast<const MapNode*>(data_.get())->data;
459       ObjectPtr<Object>(std::move(n)).swap(data_);
460     }
461     return static_cast<MapNode*>(data_.get());
462   }
463   /*!
464    * \brief set the Map.
465    * \param key The index key.
466    * \param value The value to be setted.
467    */
Set(const K & key,const V & value)468   inline void Set(const K& key, const V& value) {
469     MapNode* n = this->CopyOnWrite();
470     n->data[key] = value;
471   }
472 
473   /*! \return whether array is empty */
empty()474   inline bool empty() const {
475     return size() == 0;
476   }
477   /*! \brief specify container node */
478   using ContainerType = MapNode;
479 
480   struct ValueConverter {
481     using ResultType = std::pair<K, V>;
convertValueConverter482     static inline ResultType convert(const std::pair<
483                                      ObjectRef,
484                                      ObjectRef>& n) {
485       return std::make_pair(DowncastNoCheck<K>(n.first),
486                             DowncastNoCheck<V>(n.second));
487     }
488   };
489 
490   using iterator = IterAdapter<
491     ValueConverter, MapNode::ContainerType::const_iterator>;
492 
493   /*! \return begin iterator */
begin()494   inline iterator begin() const {
495     return iterator(static_cast<const MapNode*>(data_.get())->data.begin());
496   }
497   /*! \return end iterator */
end()498   inline iterator end() const {
499     return iterator(static_cast<const MapNode*>(data_.get())->data.end());
500   }
501   /*! \return begin iterator */
find(const K & key)502   inline iterator find(const K& key) const {
503     return iterator(
504         static_cast<const MapNode*>(data_.get())->data.find(key));
505   }
506 };
507 
508 // specialize of string map
509 template<typename V, typename T1, typename T2>
510 class Map<std::string, V, T1, T2> : public NodeRef {
511  public:
512   // for code reuse
Map()513   Map() {
514     data_ = make_node<StrMapNode>();
515   }
Map(Map<std::string,V> && other)516   Map(Map<std::string, V> && other) {  // NOLINT(*)
517     data_ = std::move(other.data_);
518   }
Map(const Map<std::string,V> & other)519   Map(const Map<std::string, V> &other) : NodeRef(other.data_) { // NOLINT(*)
520   }
Map(ObjectPtr<Object> n)521   explicit Map(ObjectPtr<Object> n) : NodeRef(n) {}
522   template<typename IterType>
Map(IterType begin,IterType end)523   Map(IterType begin, IterType end) {
524     assign(begin, end);
525   }
Map(std::initializer_list<std::pair<std::string,V>> init)526   Map(std::initializer_list<std::pair<std::string, V> > init) { // NOLINT(*)
527     assign(init.begin(), init.end());
528   }
529 
530   template<typename Hash, typename Equal>
Map(const std::unordered_map<std::string,V,Hash,Equal> & init)531   Map(const std::unordered_map<std::string, V, Hash, Equal>& init) { // NOLINT(*)
532     assign(init.begin(), init.end());
533   }
534   Map<std::string, V>& operator=(Map<std::string, V> && other) {
535     data_ = std::move(other.data_);
536     return *this;
537   }
538   Map<std::string, V>& operator=(const Map<std::string, V> & other) {
539     data_ = other.data_;
540     return *this;
541   }
542   template<typename IterType>
assign(IterType begin,IterType end)543   void assign(IterType begin, IterType end) {
544     auto n = make_node<StrMapNode>();
545     for (IterType i = begin; i != end; ++i) {
546       n->data.emplace(std::make_pair(i->first, i->second));
547     }
548     data_ = std::move(n);
549   }
550   inline const V operator[](const std::string& key) const {
551     return DowncastNoCheck<V>(
552         static_cast<const StrMapNode*>(data_.get())->data.at(key));
553   }
at(const std::string & key)554   inline const V at(const std::string& key) const {
555     return DowncastNoCheck<V>(
556         static_cast<const StrMapNode*>(data_.get())->data.at(key));
557   }
size()558   inline size_t size() const {
559     if (data_.get() == nullptr) return 0;
560     return static_cast<const StrMapNode*>(data_.get())->data.size();
561   }
count(const std::string & key)562   inline size_t count(const std::string& key) const {
563     if (data_.get() == nullptr) return 0;
564     return static_cast<const StrMapNode*>(data_.get())->data.count(key);
565   }
CopyOnWrite()566   inline StrMapNode* CopyOnWrite() {
567     if (data_.get() == nullptr || !data_.unique())  {
568       NodePtr<StrMapNode> n = make_node<StrMapNode>();
569       n->data = static_cast<const StrMapNode*>(data_.get())->data;
570       ObjectPtr<Object>(std::move(n)).swap(data_);
571     }
572     return static_cast<StrMapNode*>(data_.get());
573   }
Set(const std::string & key,const V & value)574   inline void Set(const std::string& key, const V& value) {
575     StrMapNode* n = this->CopyOnWrite();
576     n->data[key] = value;
577   }
empty()578   inline bool empty() const {
579     return size() == 0;
580   }
581   using ContainerType = StrMapNode;
582 
583   struct ValueConverter {
584     using ResultType = std::pair<std::string, V>;
convertValueConverter585     static inline ResultType convert(const std::pair<
586                                      std::string,
587                                      ObjectRef>& n) {
588       return std::make_pair(n.first, DowncastNoCheck<V>(n.second));
589     }
590   };
591 
592   using iterator = IterAdapter<
593     ValueConverter, StrMapNode::ContainerType::const_iterator>;
594 
595   /*! \return begin iterator */
begin()596   inline iterator begin() const {
597     return iterator(static_cast<const StrMapNode*>(data_.get())->data.begin());
598   }
599   /*! \return end iterator */
end()600   inline iterator end() const {
601     return iterator(static_cast<const StrMapNode*>(data_.get())->data.end());
602   }
603   /*! \return begin iterator */
find(const std::string & key)604   inline iterator find(const std::string& key) const {
605     return iterator(static_cast<const StrMapNode*>(data_.get())->data.find(key));
606   }
607 };
608 
609 }  // namespace tvm
610 #endif  // TVM_NODE_CONTAINER_H_
611