1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20 /*!
21 * \file tvm/runtime/container.h
22 * \brief Common POD(plain old data) container types.
23 */
24 #ifndef TVM_RUNTIME_CONTAINER_H_
25 #define TVM_RUNTIME_CONTAINER_H_
26
27 #include <dmlc/logging.h>
28 #include <tvm/runtime/memory.h>
29 #include <tvm/runtime/object.h>
30
31 #include <algorithm>
32 #include <cstring>
33 #include <initializer_list>
34 #include <memory>
35 #include <string>
36 #include <unordered_map>
// We use c++14 std::experimental::string_view for optimizing hash computation
// only right now; its usage is limited to this file. Any broader usage of
// std::experimental in our core codebase is discouraged and needs community
// discussion for each use case. Reference for feature test macros of
// string_view:
// https://isocpp.org/std/standing-documents/sd-6-sg10-feature-test-recommendations
// https://en.cppreference.com/w/User:D41D8CD98F/feature_testing_macros
44 #if defined(__cpp_lib_experimental_string_view) && __cpp_lib_experimental_string_view >= 201411
45 #define TVM_USE_CXX14_STRING_VIEW_HASH 1
46 #else
47 #define TVM_USE_CXX14_STRING_VIEW_HASH 0
48 #endif
49
50 // Tested with clang version 9.0.1 and c++17. It will detect string_view support
51 // correctly.
52 #if defined(__cpp_lib_string_view) && __cpp_lib_string_view >= 201606
53 #define TVM_USE_CXX17_STRING_VIEW_HASH 1
54 #else
55 #define TVM_USE_CXX17_STRING_VIEW_HASH 0
56 #endif
57
58 #if TVM_USE_CXX17_STRING_VIEW_HASH
59 #include <string_view>
60 #elif TVM_USE_CXX14_STRING_VIEW_HASH
61 #include <experimental/string_view>
62 #endif
63
64 #include <type_traits>
65 #include <utility>
66 #include <vector>
67
namespace llvm {
// String to llvm object compatibility.
// Only a forward declaration is needed here, so no LLVM header is included.
// NOTE(review): presumably used by String <-> llvm::StringRef interop declared
// later in this file (not visible in this chunk) — confirm.
class StringRef;
}  // namespace llvm
72
73 namespace tvm {
74 namespace runtime {
75
// Forward declare TVMArgValue
// Only the name is needed at this point; the full definition lives elsewhere
// (presumably tvm/runtime/packed_func.h — confirm).
class TVMArgValue;
78
79 /*! \brief String-aware ObjectRef equal functor */
80 struct ObjectHash {
81 /*!
82 * \brief Calculate the hash code of an ObjectRef
83 * \param a The given ObjectRef
84 * \return Hash code of a, string hash for strings and pointer address otherwise.
85 */
86 size_t operator()(const ObjectRef& a) const;
87 };
88
89 /*! \brief String-aware ObjectRef hash functor */
90 struct ObjectEqual {
91 /*!
92 * \brief Check if the two ObjectRef are equal
93 * \param a One ObjectRef
94 * \param b The other ObjectRef
95 * \return String equality if both are strings, pointer address equality otherwise.
96 */
97 bool operator()(const ObjectRef& a, const ObjectRef& b) const;
98 };
99
100 /*!
101 * \brief Base template for classes with array like memory layout.
102 *
103 * It provides general methods to access the memory. The memory
104 * layout is ArrayType + [ElemType]. The alignment of ArrayType
105 * and ElemType is handled by the memory allocator.
106 *
107 * \tparam ArrayType The array header type, contains object specific metadata.
108 * \tparam ElemType The type of objects stored in the array right after
109 * ArrayType.
110 *
111 * \code
112 * // Example usage of the template to define a simple array wrapper
113 * class ArrayObj : public InplaceArrayBase<ArrayObj, Elem> {
114 * public:
115 * // Wrap EmplaceInit to initialize the elements
116 * template <typename Iterator>
117 * void Init(Iterator begin, Iterator end) {
118 * size_t num_elems = std::distance(begin, end);
119 * auto it = begin;
120 * this->size = 0;
121 * for (size_t i = 0; i < num_elems; ++i) {
122 * InplaceArrayBase::EmplaceInit(i, *it++);
123 * this->size++;
124 * }
125 * }
126 * }
127 *
128 * void test_function() {
129 * vector<Elem> fields;
130 * auto ptr = make_inplace_array_object<ArrayObj, Elem>(fields.size());
131 * ptr->Init(fields.begin(), fields.end());
132 *
133 * // Access the 0th element in the array.
134 * assert(ptr->operator[](0) == fields[0]);
135 * }
136 *
137 * \endcode
138 */
139 template <typename ArrayType, typename ElemType>
140 class InplaceArrayBase {
141 public:
142 /*!
143 * \brief Access element at index
144 * \param idx The index of the element.
145 * \return Const reference to ElemType at the index.
146 */
147 const ElemType& operator[](size_t idx) const {
148 size_t size = Self()->GetSize();
149 CHECK_LT(idx, size) << "Index " << idx << " out of bounds " << size << "\n";
150 return *(reinterpret_cast<ElemType*>(AddressOf(idx)));
151 }
152
153 /*!
154 * \brief Access element at index
155 * \param idx The index of the element.
156 * \return Reference to ElemType at the index.
157 */
158 ElemType& operator[](size_t idx) {
159 size_t size = Self()->GetSize();
160 CHECK_LT(idx, size) << "Index " << idx << " out of bounds " << size << "\n";
161 return *(reinterpret_cast<ElemType*>(AddressOf(idx)));
162 }
163
164 /*!
165 * \brief Destroy the Inplace Array Base object
166 */
~InplaceArrayBase()167 ~InplaceArrayBase() {
168 if (!(std::is_standard_layout<ElemType>::value && std::is_trivial<ElemType>::value)) {
169 size_t size = Self()->GetSize();
170 for (size_t i = 0; i < size; ++i) {
171 ElemType* fp = reinterpret_cast<ElemType*>(AddressOf(i));
172 fp->ElemType::~ElemType();
173 }
174 }
175 }
176
177 protected:
178 /*!
179 * \brief Construct a value in place with the arguments.
180 *
181 * \tparam Args Type parameters of the arguments.
182 * \param idx Index of the element.
183 * \param args Arguments to construct the new value.
184 *
185 * \note Please make sure ArrayType::GetSize returns 0 before first call of
186 * EmplaceInit, and increment GetSize by 1 each time EmplaceInit succeeds.
187 */
188 template <typename... Args>
EmplaceInit(size_t idx,Args &&...args)189 void EmplaceInit(size_t idx, Args&&... args) {
190 void* field_ptr = AddressOf(idx);
191 new (field_ptr) ElemType(std::forward<Args>(args)...);
192 }
193
194 /*!
195 * \brief Return the self object for the array.
196 *
197 * \return Pointer to ArrayType.
198 */
Self()199 inline ArrayType* Self() const {
200 return static_cast<ArrayType*>(const_cast<InplaceArrayBase*>(this));
201 }
202
203 /*!
204 * \brief Return the raw pointer to the element at idx.
205 *
206 * \param idx The index of the element.
207 * \return Raw pointer to the element.
208 */
AddressOf(size_t idx)209 void* AddressOf(size_t idx) const {
210 static_assert(
211 alignof(ArrayType) % alignof(ElemType) == 0 && sizeof(ArrayType) % alignof(ElemType) == 0,
212 "The size and alignment of ArrayType should respect "
213 "ElemType's alignment.");
214
215 size_t kDataStart = sizeof(ArrayType);
216 ArrayType* self = Self();
217 char* data_start = reinterpret_cast<char*>(self) + kDataStart;
218 return data_start + idx * sizeof(ElemType);
219 }
220 };
221
/*!
 * \brief Forward adapter over TIter whose dereference yields
 *  Converter::convert(*underlying) instead of the raw element.
 * \tparam Converter Supplies `ResultType` and a static `convert` function.
 * \tparam TIter The wrapped iterator type; its traits are forwarded.
 */
template <typename Converter, typename TIter>
class IterAdapter {
 public:
  using difference_type = typename std::iterator_traits<TIter>::difference_type;
  using value_type = typename Converter::ResultType;
  using pointer = typename Converter::ResultType*;
  using reference = typename Converter::ResultType&;
  using iterator_category = typename std::iterator_traits<TIter>::iterator_category;

  explicit IterAdapter(TIter iter) : iter_(iter) {}

  /*! \brief Advance by one position (prefix). */
  IterAdapter& operator++() {
    ++iter_;
    return *this;
  }

  /*! \brief Step back by one position (prefix). */
  IterAdapter& operator--() {
    --iter_;
    return *this;
  }

  /*! \brief Advance by one position, returning the position held before. */
  IterAdapter operator++(int) {
    IterAdapter previous(iter_);
    ++iter_;
    return previous;
  }

  /*! \brief Step back by one position, returning the position held before. */
  IterAdapter operator--(int) {
    IterAdapter previous(iter_);
    --iter_;
    return previous;
  }

  /*! \brief A new adapter moved forward by `offset`. */
  IterAdapter operator+(difference_type offset) const {
    TIter moved = iter_ + offset;
    return IterAdapter(moved);
  }

  /*! \brief A new adapter moved backward by `offset`. */
  IterAdapter operator-(difference_type offset) const {
    TIter moved = iter_ - offset;
    return IterAdapter(moved);
  }

  /*!
   * \brief Distance between two adapters.
   * Only participates in overload resolution for random-access TIter.
   */
  template <typename T = IterAdapter>
  inline typename std::enable_if<
      std::is_same<iterator_category, std::random_access_iterator_tag>::value,
      typename T::difference_type>::type
  operator-(const IterAdapter& rhs) const {
    return iter_ - rhs.iter_;
  }

  bool operator==(IterAdapter other) const { return iter_ == other.iter_; }
  bool operator!=(IterAdapter other) const { return !(iter_ == other.iter_); }

  /*! \brief The converted value at the current position. */
  const value_type operator*() const { return Converter::convert(*iter_); }

 private:
  /*! \brief The wrapped iterator. */
  TIter iter_;
};
274
/*!
 * \brief Iterator adapter that walks TIter backwards and converts every
 *  dereferenced element through Converter::convert.
 *
 * Incrementing the adapter decrements the wrapped iterator and vice versa,
 * so it is constructed from the position of the first element to visit
 * (e.g. the last element of a sequence).
 *
 * \tparam Converter a struct that contains converting function
 * \tparam TIter the content iterator type.
 */
template <typename Converter, typename TIter>
class ReverseIterAdapter {
 public:
  using difference_type = typename std::iterator_traits<TIter>::difference_type;
  using value_type = typename Converter::ResultType;
  using pointer = typename Converter::ResultType*;
  using reference = typename Converter::ResultType&;  // NOLINT(*)
  using iterator_category = typename std::iterator_traits<TIter>::iterator_category;

  explicit ReverseIterAdapter(TIter iter) : iter_(iter) {}
  ReverseIterAdapter& operator++() {
    --iter_;
    return *this;
  }
  ReverseIterAdapter& operator--() {
    ++iter_;
    return *this;
  }
  // Postfix operators must return the previous position by value: `copy` is a
  // local, so returning a reference to it would dangle.
  ReverseIterAdapter operator++(int) {
    ReverseIterAdapter copy = *this;
    --iter_;
    return copy;
  }
  ReverseIterAdapter operator--(int) {
    ReverseIterAdapter copy = *this;
    ++iter_;
    return copy;
  }
  /*! \brief Move `offset` positions forward in reverse order. */
  ReverseIterAdapter operator+(difference_type offset) const {
    return ReverseIterAdapter(iter_ - offset);
  }
  /*! \brief Move `offset` positions backward in reverse order. */
  ReverseIterAdapter operator-(difference_type offset) const {
    return ReverseIterAdapter(iter_ + offset);
  }

  /*!
   * \brief Distance between two reverse adapters (in reverse order).
   * Only participates in overload resolution for random-access TIter.
   */
  template <typename T = ReverseIterAdapter>
  typename std::enable_if<std::is_same<iterator_category, std::random_access_iterator_tag>::value,
                          typename T::difference_type>::type inline
  operator-(const ReverseIterAdapter& rhs) const {
    return rhs.iter_ - iter_;
  }

  bool operator==(ReverseIterAdapter other) const { return iter_ == other.iter_; }
  bool operator!=(ReverseIterAdapter other) const { return !(*this == other); }
  const value_type operator*() const { return Converter::convert(*iter_); }

 private:
  TIter iter_;
};
326
/*!
 * \brief Container node backing Array.
 *
 * The node is an Object header followed in place (laid out by
 * InplaceArrayBase) by `capacity_` slots of ObjectRef, of which only the first
 * `size_` slots hold constructed objects. The mutating helpers are private and
 * are used by the friend reference class Array.
 */
class ArrayNode : public Object, public InplaceArrayBase<ArrayNode, ObjectRef> {
 public:
  /*! \return The size of the array */
  size_t size() const { return this->size_; }

  /*!
   * \brief Read i-th element from array.
   * \param i The index
   * \return the i-th element.
   * \note The bounds check is done by InplaceArrayBase::operator[], which takes
   *  a size_t; a negative i converts to a large value and fails that check.
   */
  const ObjectRef at(int64_t i) const { return this->operator[](i); }

  /*! \return begin constant iterator (address of the first element slot) */
  const ObjectRef* begin() const { return static_cast<ObjectRef*>(InplaceArrayBase::AddressOf(0)); }

  /*! \return end constant iterator (one past the last constructed element) */
  const ObjectRef* end() const { return begin() + size_; }

  /*! \brief Release reference to all the elements; capacity_ is not changed. */
  void clear() { ShrinkBy(size_); }

  /*!
   * \brief Set i-th element of the array in-place
   * \param i The index (bounds-checked by InplaceArrayBase::operator[])
   * \param item The value to be set
   */
  void SetItem(int64_t i, ObjectRef item) { this->operator[](i) = std::move(item); }

  /*!
   * \brief Constructs a container and copy from another
   * \param cap The capacity of the container; must be at least from->size_
   * \param from Source of the copy; it is only read
   * \return Ref-counted ArrayNode requested
   */
  static ObjectPtr<ArrayNode> CopyFrom(int64_t cap, ArrayNode* from) {
    int64_t size = from->size_;
    CHECK_GE(cap, size) << "ValueError: not enough capacity";
    ObjectPtr<ArrayNode> p = ArrayNode::Empty(cap);
    ObjectRef* write = p->MutableBegin();
    ObjectRef* read = from->MutableBegin();
    // To ensure exception safety, size is only incremented after the initialization succeeds
    // (`i` is a reference bound to p->size_, so the loop counter *is* the size).
    for (int64_t& i = p->size_ = 0; i < size; ++i) {
      new (write++) ObjectRef(*read++);
    }
    return p;
  }

  /*!
   * \brief Constructs a container and move from another
   * \param cap The capacity of the container; must be at least from->size_
   * \param from Source of the move; its size_ is set to 0 afterwards
   * \return Ref-counted ArrayNode requested
   */
  static ObjectPtr<ArrayNode> MoveFrom(int64_t cap, ArrayNode* from) {
    int64_t size = from->size_;
    CHECK_GE(cap, size) << "ValueError: not enough capacity";
    ObjectPtr<ArrayNode> p = ArrayNode::Empty(cap);
    ObjectRef* write = p->MutableBegin();
    ObjectRef* read = from->MutableBegin();
    // To ensure exception safety, size is only incremented after the initialization succeeds
    // (`i` is a reference bound to p->size_).
    for (int64_t& i = p->size_ = 0; i < size; ++i) {
      new (write++) ObjectRef(std::move(*read++));
    }
    // The source slots now hold moved-from refs; mark the source as empty.
    from->size_ = 0;
    return p;
  }

  /*!
   * \brief Constructs a container with n elements. Each element is a copy of val
   * \param n The size of the container (also its capacity)
   * \param val The init value
   * \return Ref-counted ArrayNode requested
   */
  static ObjectPtr<ArrayNode> CreateRepeated(int64_t n, const ObjectRef& val) {
    ObjectPtr<ArrayNode> p = ArrayNode::Empty(n);
    ObjectRef* itr = p->MutableBegin();
    // `i` aliases p->size_, so size tracks the number of constructed elements.
    for (int64_t& i = p->size_ = 0; i < n; ++i) {
      new (itr++) ObjectRef(val);
    }
    return p;
  }

  static constexpr const uint32_t _type_index = TypeIndex::kRuntimeArray;
  static constexpr const char* _type_key = "Array";
  TVM_DECLARE_FINAL_OBJECT_INFO(ArrayNode, Object);

 private:
  /*! \return Size of initialized memory, used by InplaceArrayBase. */
  size_t GetSize() const { return this->size_; }

  /*! \return begin mutable iterator */
  ObjectRef* MutableBegin() const {
    return static_cast<ObjectRef*>(InplaceArrayBase::AddressOf(0));
  }

  /*! \return end mutable iterator */
  ObjectRef* MutableEnd() const { return MutableBegin() + size_; }

  /*!
   * \brief Create an ArrayNode with the given capacity.
   * \param n Required capacity; must be non-negative
   * \return Ref-counted ArrayNode requested, with size_ == 0
   */
  static ObjectPtr<ArrayNode> Empty(int64_t n = kInitSize) {
    CHECK_GE(n, 0);
    ObjectPtr<ArrayNode> p = make_inplace_array_object<ArrayNode, ObjectRef>(n);
    p->capacity_ = n;
    p->size_ = 0;
    return p;
  }

  /*!
   * \brief Inplace-initialize the elements starting idx from [first, last)
   *
   * Placement-constructs into the target slots without destroying what is
   * there and without changing size_; callers are responsible for both.
   *
   * \param idx The starting point
   * \param first Begin of iterator
   * \param last End of iterator
   * \tparam IterType The type of iterator
   * \return Self
   */
  template <typename IterType>
  ArrayNode* InitRange(int64_t idx, IterType first, IterType last) {
    ObjectRef* itr = MutableBegin() + idx;
    for (; first != last; ++first) {
      ObjectRef ref = *first;
      new (itr++) ObjectRef(std::move(ref));
    }
    return this;
  }

  /*!
   * \brief Move elements from right to left, requires src_begin > dst
   *
   * Uses move-assignment, so the destination slots must already be constructed.
   * Iterates front to back, which is safe for overlapping ranges moving left.
   *
   * \param dst Destination
   * \param src_begin The start point of copy (inclusive)
   * \param src_end The end point of copy (exclusive)
   * \return Self
   */
  ArrayNode* MoveElementsLeft(int64_t dst, int64_t src_begin, int64_t src_end) {
    ObjectRef* from = MutableBegin() + src_begin;
    ObjectRef* to = MutableBegin() + dst;
    while (src_begin++ != src_end) {
      *to++ = std::move(*from++);
    }
    return this;
  }

  /*!
   * \brief Move elements from left to right, requires src_begin < dst
   *
   * Uses move-assignment, so the destination slots must already be constructed.
   * Iterates back to front, which is safe for overlapping ranges moving right.
   *
   * \param dst Destination
   * \param src_begin The start point of move (inclusive)
   * \param src_end The end point of move (exclusive)
   * \return Self
   */
  ArrayNode* MoveElementsRight(int64_t dst, int64_t src_begin, int64_t src_end) {
    ObjectRef* from = MutableBegin() + src_end;
    ObjectRef* to = MutableBegin() + (src_end - src_begin + dst);
    while (src_begin++ != src_end) {
      *--to = std::move(*--from);
    }
    return this;
  }

  /*!
   * \brief Enlarges the size of the array
   *
   * Constructs `delta` copies of `val` after the last element. Capacity is not
   * checked here; the caller must have reserved enough room.
   *
   * \param delta Size enlarged; non-positive values do nothing
   * \param val Default value
   * \return Self
   */
  ArrayNode* EnlargeBy(int64_t delta, const ObjectRef& val = ObjectRef(nullptr)) {
    ObjectRef* itr = MutableEnd();
    while (delta-- > 0) {
      new (itr++) ObjectRef(val);
      ++size_;
    }
    return this;
  }

  /*!
   * \brief Shrinks the size of the array
   *
   * Destroys the last `delta` elements; capacity_ is unchanged.
   *
   * \param delta Size shrunk; non-positive values do nothing
   * \return Self
   */
  ArrayNode* ShrinkBy(int64_t delta) {
    ObjectRef* itr = MutableEnd();
    while (delta-- > 0) {
      (--itr)->ObjectRef::~ObjectRef();
      --size_;
    }
    return this;
  }

  /*! \brief Number of elements used */
  int64_t size_;

  /*! \brief Number of elements allocated */
  int64_t capacity_;

  /*! \brief Initial size of ArrayNode */
  static constexpr int64_t kInitSize = 4;

  /*! \brief Expansion factor of the Array */
  static constexpr int64_t kIncFactor = 2;

  // CRTP parent class
  friend InplaceArrayBase<ArrayNode, ObjectRef>;

  // Reference class
  template <typename, typename>
  friend class Array;

  // To specialize make_object<ArrayNode>
  friend ObjectPtr<ArrayNode> make_object<>();
};
540
541 /*!
542 * \brief Array, container representing a contigious sequence of ObjectRefs.
543 *
544 * Array implements in-place copy-on-write semantics.
545 *
546 * As in typical copy-on-write, a method which would typically mutate the array
547 * instead opaquely copies the underlying container, and then acts on its copy.
548 *
549 * If the array has reference count equal to one, we directly update the
550 * container in place without copying. This is optimization is sound because
551 * when the reference count is equal to one this reference is guranteed to be
552 * the sole pointer to the container.
553 *
554 *
555 * operator[] only provides const access, use Set to mutate the content.
556 * \tparam T The content ObjectRef type.
557 */
558 template <typename T,
559 typename = typename std::enable_if<std::is_base_of<ObjectRef, T>::value>::type>
560 class Array : public ObjectRef {
561 public:
562 using value_type = T;
563 // constructors
564 /*!
565 * \brief default constructor
566 */
Array()567 Array() { data_ = ArrayNode::Empty(); }
568
569 /*!
570 * \brief move constructor
571 * \param other source
572 */
Array(Array<T> && other)573 Array(Array<T>&& other) : ObjectRef() { // NOLINT(*)
574 data_ = std::move(other.data_);
575 }
576
577 /*!
578 * \brief copy constructor
579 * \param other source
580 */
Array(const Array<T> & other)581 Array(const Array<T>& other) : ObjectRef() { // NOLINT(*)
582 data_ = other.data_;
583 }
584
585 /*!
586 * \brief constructor from pointer
587 * \param n the container pointer
588 */
Array(ObjectPtr<Object> n)589 explicit Array(ObjectPtr<Object> n) : ObjectRef(n) {}
590
591 /*!
592 * \brief Constructor from iterator
593 * \param first begin of iterator
594 * \param last end of iterator
595 * \tparam IterType The type of iterator
596 */
597 template <typename IterType>
Array(IterType first,IterType last)598 Array(IterType first, IterType last) {
599 Assign(first, last);
600 }
601
602 /*!
603 * \brief constructor from initializer list
604 * \param init The initializer list
605 */
Array(std::initializer_list<T> init)606 Array(std::initializer_list<T> init) { // NOLINT(*)
607 Assign(init.begin(), init.end());
608 }
609
610 /*!
611 * \brief constructor from vector
612 * \param init The vector
613 */
Array(const std::vector<T> & init)614 Array(const std::vector<T>& init) { // NOLINT(*)
615 Assign(init.begin(), init.end());
616 }
617
618 /*!
619 * \brief Constructs a container with n elements. Each element is a copy of val
620 * \param n The size of the container
621 * \param val The init value
622 */
Array(const size_t n,const T & val)623 explicit Array(const size_t n, const T& val) { data_ = ArrayNode::CreateRepeated(n, val); }
624
625 /*!
626 * \brief move assign operator
627 * \param other The source of assignment
628 * \return reference to self.
629 */
630 Array<T>& operator=(Array<T>&& other) {
631 data_ = std::move(other.data_);
632 return *this;
633 }
634
635 /*!
636 * \brief copy assign operator
637 * \param other The source of assignment
638 * \return reference to self.
639 */
640 Array<T>& operator=(const Array<T>& other) {
641 data_ = other.data_;
642 return *this;
643 }
644
645 public:
646 // iterators
647 struct ValueConverter {
648 using ResultType = T;
convertValueConverter649 static T convert(const ObjectRef& n) { return DowncastNoCheck<T>(n); }
650 };
651
652 using iterator = IterAdapter<ValueConverter, const ObjectRef*>;
653 using reverse_iterator = ReverseIterAdapter<ValueConverter, const ObjectRef*>;
654
655 /*! \return begin iterator */
begin()656 iterator begin() const { return iterator(GetArrayNode()->begin()); }
657
658 /*! \return end iterator */
end()659 iterator end() const { return iterator(GetArrayNode()->end()); }
660
661 /*! \return rbegin iterator */
rbegin()662 reverse_iterator rbegin() const {
663 // ArrayNode::end() is never nullptr
664 return reverse_iterator(GetArrayNode()->end() - 1);
665 }
666
667 /*! \return rend iterator */
rend()668 reverse_iterator rend() const {
669 // ArrayNode::begin() is never nullptr
670 return reverse_iterator(GetArrayNode()->begin() - 1);
671 }
672
673 public:
674 // const methods in std::vector
675 /*!
676 * \brief Immutably read i-th element from array.
677 * \param i The index
678 * \return the i-th element.
679 */
680 const T operator[](int64_t i) const {
681 ArrayNode* p = GetArrayNode();
682 CHECK(p != nullptr) << "ValueError: cannot index a null array";
683 CHECK(0 <= i && i < p->size_) << "IndexError: indexing " << i << " on an array of size "
684 << p->size_;
685 return DowncastNoCheck<T>(*(p->begin() + i));
686 }
687
688 /*! \return The size of the array */
size()689 size_t size() const {
690 ArrayNode* p = GetArrayNode();
691 return p == nullptr ? 0 : GetArrayNode()->size_;
692 }
693
694 /*! \return The capacity of the array */
capacity()695 size_t capacity() const {
696 ArrayNode* p = GetArrayNode();
697 return p == nullptr ? 0 : GetArrayNode()->capacity_;
698 }
699
700 /*! \return Whether array is empty */
empty()701 bool empty() const { return size() == 0; }
702
703 /*! \return The first element of the array */
front()704 const T front() const {
705 ArrayNode* p = GetArrayNode();
706 CHECK(p != nullptr) << "ValueError: cannot index a null array";
707 CHECK_GT(p->size_, 0) << "IndexError: cannot index an empty array";
708 return DowncastNoCheck<T>(*(p->begin()));
709 }
710
711 /*! \return The last element of the array */
back()712 const T back() const {
713 ArrayNode* p = GetArrayNode();
714 CHECK(p != nullptr) << "ValueError: cannot index a null array";
715 CHECK_GT(p->size_, 0) << "IndexError: cannot index an empty array";
716 return DowncastNoCheck<T>(*(p->end() - 1));
717 }
718
719 public:
720 // mutation in std::vector, implements copy-on-write
721
722 /*!
723 * \brief push a new item to the back of the list
724 * \param item The item to be pushed.
725 */
push_back(const T & item)726 void push_back(const T& item) {
727 ArrayNode* p = CopyOnWrite(1);
728 p->EmplaceInit(p->size_++, item);
729 }
730
731 /*!
732 * \brief Insert an element into the given position
733 * \param position An iterator pointing to the insertion point
734 * \param val The element to insert
735 */
insert(iterator position,const T & val)736 void insert(iterator position, const T& val) {
737 CHECK(data_ != nullptr) << "ValueError: cannot insert a null array";
738 int64_t idx = std::distance(begin(), position);
739 int64_t size = GetArrayNode()->size_;
740 auto addr = CopyOnWrite(1) //
741 ->EnlargeBy(1) //
742 ->MoveElementsRight(idx + 1, idx, size) //
743 ->MutableBegin();
744 new (addr + idx) ObjectRef(val);
745 }
746
747 /*!
748 * \brief Insert a range of elements into the given position
749 * \param position An iterator pointing to the insertion point
750 * \param first The begin iterator of the range
751 * \param last The end iterator of the range
752 */
753 template <typename IterType>
insert(iterator position,IterType first,IterType last)754 void insert(iterator position, IterType first, IterType last) {
755 if (first == last) {
756 return;
757 }
758 CHECK(data_ != nullptr) << "ValueError: cannot insert a null array";
759 int64_t idx = std::distance(begin(), position);
760 int64_t size = GetArrayNode()->size_;
761 int64_t numel = std::distance(first, last);
762 CopyOnWrite(numel)
763 ->EnlargeBy(numel)
764 ->MoveElementsRight(idx + numel, idx, size)
765 ->InitRange(idx, first, last);
766 }
767
768 /*! \brief Remove the last item of the list */
pop_back()769 void pop_back() {
770 CHECK(data_ != nullptr) << "ValueError: cannot pop_back because array is null";
771 int64_t size = GetArrayNode()->size_;
772 CHECK_GT(size, 0) << "ValueError: cannot pop_back because array is empty";
773 CopyOnWrite()->ShrinkBy(1);
774 }
775
776 /*!
777 * \brief Erase an element on the given position
778 * \param position An iterator pointing to the element to be erased
779 */
erase(iterator position)780 void erase(iterator position) {
781 CHECK(data_ != nullptr) << "ValueError: cannot erase a null array";
782 int64_t st = std::distance(begin(), position);
783 int64_t size = GetArrayNode()->size_;
784 CHECK(0 <= st && st < size) << "ValueError: cannot erase at index " << st
785 << ", because Array size is " << size;
786 CopyOnWrite() //
787 ->MoveElementsLeft(st, st + 1, size) //
788 ->ShrinkBy(1);
789 }
790
791 /*!
792 * \brief Erase a given range of elements
793 * \param first The begin iterator of the range
794 * \param last The end iterator of the range
795 */
erase(iterator first,iterator last)796 void erase(iterator first, iterator last) {
797 if (first == last) {
798 return;
799 }
800 CHECK(data_ != nullptr) << "ValueError: cannot erase a null array";
801 int64_t size = GetArrayNode()->size_;
802 int64_t st = std::distance(begin(), first);
803 int64_t ed = std::distance(begin(), last);
804 CHECK_LT(st, ed) << "ValueError: cannot erase array in range [" << st << ", " << ed << ")";
805 CHECK(0 <= st && st <= size && 0 <= ed && ed <= size)
806 << "ValueError: cannot erase array in range [" << st << ", " << ed << ")"
807 << ", because array size is " << size;
808 CopyOnWrite() //
809 ->MoveElementsLeft(st, ed, size) //
810 ->ShrinkBy(ed - st);
811 }
812
813 /*!
814 * \brief Resize the array.
815 * \param n The new size.
816 */
resize(int64_t n)817 void resize(int64_t n) {
818 CHECK_GE(n, 0) << "ValueError: cannot resize an Array to negative size";
819 if (data_ == nullptr) {
820 SwitchContainer(n);
821 return;
822 }
823 int64_t size = GetArrayNode()->size_;
824 if (size < n) {
825 CopyOnWrite(n - size)->EnlargeBy(n - size);
826 } else if (size > n) {
827 CopyOnWrite()->ShrinkBy(size - n);
828 }
829 }
830
831 /*!
832 * \brief Make sure the list has the capacity of at least n
833 * \param n lower bound of the capacity
834 */
reserve(int64_t n)835 void reserve(int64_t n) {
836 if (data_ == nullptr || n > GetArrayNode()->capacity_) {
837 SwitchContainer(n);
838 }
839 }
840
841 /*! \brief Release reference to all the elements */
clear()842 void clear() {
843 if (data_ != nullptr) {
844 ArrayNode* p = CopyOnWrite();
845 p->clear();
846 }
847 }
848
849 public:
850 // Array's own methods
851
852 /*!
853 * \brief set i-th element of the array.
854 * \param i The index
855 * \param value The value to be setted.
856 */
Set(int64_t i,T value)857 void Set(int64_t i, T value) {
858 ArrayNode* p = this->CopyOnWrite();
859 CHECK(0 <= i && i < p->size_) << "IndexError: indexing " << i << " on an array of size "
860 << p->size_;
861 *(p->MutableBegin() + i) = std::move(value);
862 }
863
864 /*! \return The underlying ArrayNode */
GetArrayNode()865 ArrayNode* GetArrayNode() const { return static_cast<ArrayNode*>(data_.get()); }
866
867 /*!
868 * \brief Helper function to apply fmutate to mutate an array.
869 * \param fmutate The transformation function T -> T.
870 * \tparam F the type of the mutation function.
871 * \note This function performs copy on write optimization.
872 */
873 template <typename F>
MutateByApply(F fmutate)874 void MutateByApply(F fmutate) {
875 if (data_ == nullptr) {
876 return;
877 }
878 struct StackFrame {
879 ArrayNode* p;
880 ObjectRef* itr;
881 int64_t i;
882 int64_t size;
883 };
884 std::unique_ptr<StackFrame> s = std::make_unique<StackFrame>();
885 s->p = GetArrayNode();
886 s->itr = s->p->MutableBegin();
887 s->i = 0;
888 s->size = s->p->size_;
889 if (!data_.unique()) {
890 // Loop invariant: keeps iterating when
891 // 1) data is not unique
892 // 2) no elements are actually mutated yet
893 for (; s->i < s->size; ++s->i, ++s->itr) {
894 T new_elem = fmutate(DowncastNoCheck<T>(*s->itr));
895 // do nothing when there is no mutation
896 if (new_elem.same_as(*s->itr)) {
897 continue;
898 }
899 // loop invariant breaks when the first real mutation happens
900 // we copy the elements into a new unique array
901 ObjectPtr<ArrayNode> copy = ArrayNode::CopyFrom(s->p->capacity_, s->p);
902 s->itr = copy->MutableBegin() + (s->i++);
903 *s->itr++ = std::move(new_elem);
904 data_ = std::move(copy);
905 // make sure `data_` is unique and break
906 break;
907 }
908 }
909 // when execution comes to this line, it is guaranteed that either
910 // 1) i == size
911 // or 2) data_.unique() is true
912 for (; s->i < s->size; ++s->i, ++s->itr) {
913 *s->itr = std::move(fmutate(std::move(DowncastNoCheck<T>(std::move(*s->itr)))));
914 }
915 }
916
917 /*!
918 * \brief reset the array to content from iterator.
919 * \param first begin of iterator
920 * \param last end of iterator
921 * \tparam IterType The type of iterator
922 */
923 template <typename IterType>
Assign(IterType first,IterType last)924 void Assign(IterType first, IterType last) {
925 int64_t cap = std::distance(first, last);
926 CHECK_GE(cap, 0) << "ValueError: cannot construct an Array of negative size";
927 ArrayNode* p = GetArrayNode();
928 if (p != nullptr && data_.unique() && p->capacity_ >= cap) {
929 // do not have to make new space
930 p->clear();
931 } else {
932 // create new space
933 data_ = ArrayNode::Empty(cap);
934 p = GetArrayNode();
935 }
936 // To ensure exception safety, size is only incremented after the initialization succeeds
937 ObjectRef* itr = p->MutableBegin();
938 for (int64_t& i = p->size_ = 0; i < cap; ++i, ++first, ++itr) {
939 new (itr) ObjectRef(*first);
940 }
941 }
942
943 /*!
944 * \brief Copy on write semantics
945 * Do nothing if current handle is the unique copy of the array.
946 * Otherwise make a new copy of the array to ensure the current handle
947 * hold a unique copy.
948 *
949 * \return Handle to the internal node container(which ganrantees to be unique)
950 */
CopyOnWrite()951 ArrayNode* CopyOnWrite() {
952 if (data_ == nullptr) {
953 return SwitchContainer(ArrayNode::kInitSize);
954 }
955 if (!data_.unique()) {
956 return SwitchContainer(capacity());
957 }
958 return static_cast<ArrayNode*>(data_.get());
959 }
960
/*!
 * \brief The node type that backs this reference.
 * \note Generic wrappers read this alias; e.g. Optional<T> defines its own
 *       ContainerType as T::ContainerType.
 */
using ContainerType = ArrayNode;

private:
965 /*!
966 * \brief Implement copy-on-write semantics, and ensures capacity is enough for extra elements.
967 * \param reserve_extra Number of extra slots needed
968 * \return ArrayNode pointer to the unique copy
969 */
CopyOnWrite(int64_t reserve_extra)970 ArrayNode* CopyOnWrite(int64_t reserve_extra) {
971 ArrayNode* p = GetArrayNode();
972 if (p == nullptr) {
973 // necessary to get around the constexpr address issue before c++17
974 const int64_t kInitSize = ArrayNode::kInitSize;
975 return SwitchContainer(std::max(kInitSize, reserve_extra));
976 }
977 if (p->capacity_ >= p->size_ + reserve_extra) {
978 return CopyOnWrite();
979 }
980 int64_t cap = p->capacity_ * ArrayNode::kIncFactor;
981 cap = std::max(cap, p->size_ + reserve_extra);
982 return SwitchContainer(cap);
983 }
984
985 /*!
986 * \brief Move or copy the ArrayNode to new address with the given capacity
987 * \param capacity The capacity requirement of the new address
988 */
SwitchContainer(int64_t capacity)989 ArrayNode* SwitchContainer(int64_t capacity) {
990 if (data_ == nullptr) {
991 data_ = ArrayNode::Empty(capacity);
992 } else if (data_.unique()) {
993 data_ = ArrayNode::MoveFrom(capacity, GetArrayNode());
994 } else {
995 data_ = ArrayNode::CopyFrom(capacity, GetArrayNode());
996 }
997 return static_cast<ArrayNode*>(data_.get());
998 }
999 };
1000
1001 /*!
1002 * \brief Concat two Arrays.
1003 * \param lhs first Array to be concatenated.
1004 * \param rhs second Array to be concatenated.
1005 * \return The concatenated Array. Original Arrays are kept unchanged.
1006 */
1007 template <typename T,
1008 typename = typename std::enable_if<std::is_base_of<ObjectRef, T>::value>::type>
Concat(Array<T> lhs,const Array<T> & rhs)1009 inline Array<T> Concat(Array<T> lhs, const Array<T>& rhs) {
1010 for (const auto& x : rhs) {
1011 lhs.push_back(x);
1012 }
1013 return std::move(lhs);
1014 }
1015
1016 // Specialize make_object<ArrayNode> to make sure it is correct.
1017 template <>
make_object()1018 inline ObjectPtr<ArrayNode> make_object() {
1019 return ArrayNode::Empty();
1020 }
1021
1022 /*! \brief An object representing a structure or enumeration. */
1023 class ADTObj : public Object, public InplaceArrayBase<ADTObj, ObjectRef> {
1024 public:
1025 /*! \brief The tag representing the constructor used. */
1026 int32_t tag;
1027 /*! \brief Number of fields in the ADT object. */
1028 uint32_t size;
1029 // The fields of the structure follows directly in memory.
1030
1031 static constexpr const uint32_t _type_index = TypeIndex::kRuntimeADT;
1032 static constexpr const char* _type_key = "runtime.ADT";
1033 TVM_DECLARE_FINAL_OBJECT_INFO(ADTObj, Object);
1034
1035 private:
1036 /*!
1037 * \return The number of elements in the array.
1038 */
GetSize()1039 size_t GetSize() const { return size; }
1040
1041 /*!
1042 * \brief Initialize the elements in the array.
1043 *
1044 * \tparam Iterator Iterator type of the array.
1045 * \param begin The begin iterator.
1046 * \param end The end iterator.
1047 */
1048 template <typename Iterator>
Init(Iterator begin,Iterator end)1049 void Init(Iterator begin, Iterator end) {
1050 size_t num_elems = std::distance(begin, end);
1051 this->size = 0;
1052 auto it = begin;
1053 for (size_t i = 0; i < num_elems; ++i) {
1054 InplaceArrayBase::EmplaceInit(i, *it++);
1055 // Only increment size after the initialization succeeds
1056 this->size++;
1057 }
1058 }
1059
1060 friend class ADT;
1061 friend InplaceArrayBase<ADTObj, ObjectRef>;
1062 };
1063
1064 /*! \brief reference to algebraic data type objects. */
1065 class ADT : public ObjectRef {
1066 public:
1067 /*!
1068 * \brief construct an ADT object reference.
1069 * \param tag The tag of the ADT object.
1070 * \param fields The fields of the ADT object.
1071 * \return The constructed ADT object reference.
1072 */
ADT(int32_t tag,std::vector<ObjectRef> fields)1073 ADT(int32_t tag, std::vector<ObjectRef> fields) : ADT(tag, fields.begin(), fields.end()){};
1074
1075 /*!
1076 * \brief construct an ADT object reference.
1077 * \param tag The tag of the ADT object.
1078 * \param begin The begin iterator to the start of the fields array.
1079 * \param end The end iterator to the end of the fields array.
1080 * \return The constructed ADT object reference.
1081 */
1082 template <typename Iterator>
ADT(int32_t tag,Iterator begin,Iterator end)1083 ADT(int32_t tag, Iterator begin, Iterator end) {
1084 size_t num_elems = std::distance(begin, end);
1085 auto ptr = make_inplace_array_object<ADTObj, ObjectRef>(num_elems);
1086 ptr->tag = tag;
1087 ptr->Init(begin, end);
1088 data_ = std::move(ptr);
1089 }
1090
1091 /*!
1092 * \brief construct an ADT object reference.
1093 * \param tag The tag of the ADT object.
1094 * \param init The initializer list of fields.
1095 * \return The constructed ADT object reference.
1096 */
ADT(int32_t tag,std::initializer_list<ObjectRef> init)1097 ADT(int32_t tag, std::initializer_list<ObjectRef> init) : ADT(tag, init.begin(), init.end()){};
1098
1099 /*!
1100 * \brief Access element at index.
1101 *
1102 * \param idx The array index
1103 * \return const ObjectRef
1104 */
1105 const ObjectRef& operator[](size_t idx) const { return operator->()->operator[](idx); }
1106
1107 /*!
1108 * \brief Return the ADT tag.
1109 */
tag()1110 int32_t tag() const { return operator->()->tag; }
1111
1112 /*!
1113 * \brief Return the number of fields.
1114 */
size()1115 size_t size() const { return operator->()->size; }
1116
1117 /*!
1118 * \brief Construct a tuple object.
1119 *
1120 * \tparam Args Type params of tuple feilds.
1121 * \param args Tuple fields.
1122 * \return ADT The tuple object reference.
1123 */
1124 template <typename... Args>
Tuple(Args &&...args)1125 static ADT Tuple(Args&&... args) {
1126 return ADT(0, std::forward<Args>(args)...);
1127 }
1128
1129 TVM_DEFINE_OBJECT_REF_METHODS(ADT, ObjectRef, ADTObj);
1130 };
1131
/*!
 * \brief An object representing a string as a (data, size) pair.
 *
 * StringObj itself holds no owning member for the character buffer; subclasses
 * provide the storage (StringObj::FromStd keeps it in a std::string member).
 */
class StringObj : public Object {
 public:
  /*! \brief The pointer to string data (storage is provided by a subclass). */
  const char* data;

  /*! \brief The length of the string object, in chars. */
  uint64_t size;

  static constexpr const uint32_t _type_index = TypeIndex::kRuntimeString;
  static constexpr const char* _type_key = "runtime.String";
  TVM_DECLARE_FINAL_OBJECT_INFO(StringObj, Object);

 private:
  /*! \brief String object which is moved from std::string container. */
  class FromStd;

  friend class String;
};
1151
1152 /*!
1153 * \brief Reference to string objects.
1154 *
1155 * \code
1156 *
1157 * // Example to create runtime String reference object from std::string
1158 * std::string s = "hello world";
1159 *
1160 * // You can create the reference from existing std::string
1161 * String ref{std::move(s)};
1162 *
1163 * // You can rebind the reference to another string.
1164 * ref = std::string{"hello world2"};
1165 *
1166 * // You can use the reference as hash map key
1167 * std::unordered_map<String, int32_t> m;
1168 * m[ref] = 1;
1169 *
1170 * // You can compare the reference object with other string objects
1171 * assert(ref == "hello world", true);
1172 *
1173 * // You can convert the reference to std::string again
1174 * string s2 = (string)ref;
1175 *
1176 * \endcode
1177 */
1178 class String : public ObjectRef {
1179 public:
1180 /*!
1181 * \brief Construct an empty string.
1182 */
String()1183 String() : String(std::string()) {}
1184 /*!
1185 * \brief Construct a new String object
1186 *
1187 * \param other The moved/copied std::string object
1188 *
1189 * \note If user passes const reference, it will trigger copy. If it's rvalue,
1190 * it will be moved into other.
1191 */
1192 String(std::string other); // NOLINT(*)
1193
1194 /*!
1195 * \brief Construct a new String object
1196 *
1197 * \param other a char array.
1198 */
String(const char * other)1199 String(const char* other) // NOLINT(*)
1200 : String(std::string(other)) {}
1201
1202 /*!
1203 * \brief Change the value the reference object points to.
1204 *
1205 * \param other The value for the new String
1206 *
1207 */
1208 inline String& operator=(std::string other);
1209
1210 /*!
1211 * \brief Change the value the reference object points to.
1212 *
1213 * \param other The value for the new String
1214 */
1215 inline String& operator=(const char* other);
1216
1217 /*!
1218 * \brief Compares this String object to other
1219 *
1220 * \param other The String to compare with.
1221 *
1222 * \return zero if both char sequences compare equal. negative if this appear
1223 * before other, positive otherwise.
1224 */
compare(const String & other)1225 int compare(const String& other) const {
1226 return memncmp(data(), other.data(), size(), other.size());
1227 }
1228
1229 /*!
1230 * \brief Compares this String object to other
1231 *
1232 * \param other The string to compare with.
1233 *
1234 * \return zero if both char sequences compare equal. negative if this appear
1235 * before other, positive otherwise.
1236 */
compare(const std::string & other)1237 int compare(const std::string& other) const {
1238 return memncmp(data(), other.data(), size(), other.size());
1239 }
1240
1241 /*!
1242 * \brief Compares this to other
1243 *
1244 * \param other The character array to compare with.
1245 *
1246 * \return zero if both char sequences compare equal. negative if this appear
1247 * before other, positive otherwise.
1248 */
compare(const char * other)1249 int compare(const char* other) const {
1250 return memncmp(data(), other, size(), std::strlen(other));
1251 }
1252
1253 /*!
1254 * \brief Returns a pointer to the char array in the string.
1255 *
1256 * \return const char*
1257 */
c_str()1258 const char* c_str() const { return get()->data; }
1259
1260 /*!
1261 * \brief Return the length of the string
1262 *
1263 * \return size_t string length
1264 */
size()1265 size_t size() const {
1266 const auto* ptr = get();
1267 return ptr->size;
1268 }
1269
1270 /*!
1271 * \brief Return the length of the string
1272 *
1273 * \return size_t string length
1274 */
length()1275 size_t length() const { return size(); }
1276
1277 /*!
1278 * \brief Retun if the string is empty
1279 *
1280 * \return true if empty, false otherwise.
1281 */
empty()1282 bool empty() const { return size() == 0; }
1283
1284 /*!
1285 * \brief Return the data pointer
1286 *
1287 * \return const char* data pointer
1288 */
data()1289 const char* data() const { return get()->data; }
1290
1291 /*!
1292 * \brief Convert String to an std::string object
1293 *
1294 * \return std::string
1295 */
string()1296 operator std::string() const { return std::string{get()->data, size()}; }
1297
1298 // LLVM compatibility function, implemented in src/target/llvm/llvm_common.h
1299 /*!
1300 * \brief Convert String to an llvm::StringRef object
1301 *
1302 * \return llvm::StringRef
1303 */
1304 inline operator llvm::StringRef() const;
1305
1306 /*!
1307 * \brief Check if a TVMArgValue can be converted to String, i.e. it can be std::string or String
1308 * \param val The value to be checked
1309 * \return A boolean indicating if val can be converted to String
1310 */
1311 inline static bool CanConvertFrom(const TVMArgValue& val);
1312
1313 /*!
1314 * \brief Hash the binary bytes
1315 * \param data The data pointer
1316 * \param size The size of the bytes.
1317 * \return the hash value.
1318 */
HashBytes(const char * data,size_t size)1319 static size_t HashBytes(const char* data, size_t size) {
1320 // This function falls back to string copy with c++11 compiler and is
1321 // recommended to be compiled with c++14
1322 #if TVM_USE_CXX17_STRING_VIEW_HASH
1323 return std::hash<std::string_view>()(std::string_view(data, size));
1324 #elif TVM_USE_CXX14_STRING_VIEW_HASH
1325 return std::hash<std::experimental::string_view>()(std::experimental::string_view(data, size));
1326 #else
1327 return std::hash<std::string>()(std::string(data, size));
1328 #endif
1329 }
1330
1331 TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(String, ObjectRef, StringObj);
1332
1333 private:
1334 /*!
1335 * \brief Compare two char sequence
1336 *
1337 * \param lhs Pointers to the char array to compare
1338 * \param rhs Pointers to the char array to compare
1339 * \param lhs_count Length of the char array to compare
1340 * \param rhs_count Length of the char array to compare
1341 * \return int zero if both char sequences compare equal. negative if this
1342 * appear before other, positive otherwise.
1343 */
1344 static int memncmp(const char* lhs, const char* rhs, size_t lhs_count, size_t rhs_count);
1345
1346 /*!
1347 * \brief Concatenate two char sequences
1348 *
1349 * \param lhs Pointers to the lhs char array
1350 * \param lhs_size The size of the lhs char array
1351 * \param rhs Pointers to the rhs char array
1352 * \param rhs_size The size of the rhs char array
1353 *
1354 * \return The concatenated char sequence
1355 */
Concat(const char * lhs,size_t lhs_size,const char * rhs,size_t rhs_size)1356 static String Concat(const char* lhs, size_t lhs_size, const char* rhs, size_t rhs_size) {
1357 std::string ret(lhs, lhs_size);
1358 ret.append(rhs, rhs_size);
1359 return String(ret);
1360 }
1361
1362 // Overload + operator
1363 friend String operator+(const String& lhs, const String& rhs);
1364 friend String operator+(const String& lhs, const std::string& rhs);
1365 friend String operator+(const std::string& lhs, const String& rhs);
1366 friend String operator+(const String& lhs, const char* rhs);
1367 friend String operator+(const char* lhs, const String& rhs);
1368
1369 friend struct tvm::runtime::ObjectEqual;
1370 };
1371
1372 /*! \brief An object representing string moved from std::string. */
1373 class StringObj::FromStd : public StringObj {
1374 public:
1375 /*!
1376 * \brief Construct a new FromStd object
1377 *
1378 * \param other The moved/copied std::string object
1379 *
1380 * \note If user passes const reference, it will trigger copy. If it's rvalue,
1381 * it will be moved into other.
1382 */
FromStd(std::string other)1383 explicit FromStd(std::string other) : data_container{other} {}
1384
1385 private:
1386 /*! \brief Container that holds the memory. */
1387 std::string data_container;
1388
1389 friend class String;
1390 };
1391
String(std::string other)1392 inline String::String(std::string other) {
1393 auto ptr = make_object<StringObj::FromStd>(std::move(other));
1394 ptr->size = ptr->data_container.size();
1395 ptr->data = ptr->data_container.data();
1396 data_ = std::move(ptr);
1397 }
1398
1399 inline String& String::operator=(std::string other) {
1400 String replace{std::move(other)};
1401 data_.swap(replace.data_);
1402 return *this;
1403 }
1404
1405 inline String& String::operator=(const char* other) { return operator=(std::string(other)); }
1406
1407 inline String operator+(const String& lhs, const String& rhs) {
1408 size_t lhs_size = lhs.size();
1409 size_t rhs_size = rhs.size();
1410 return String::Concat(lhs.data(), lhs_size, rhs.data(), rhs_size);
1411 }
1412
1413 inline String operator+(const String& lhs, const std::string& rhs) {
1414 size_t lhs_size = lhs.size();
1415 size_t rhs_size = rhs.size();
1416 return String::Concat(lhs.data(), lhs_size, rhs.data(), rhs_size);
1417 }
1418
1419 inline String operator+(const std::string& lhs, const String& rhs) {
1420 size_t lhs_size = lhs.size();
1421 size_t rhs_size = rhs.size();
1422 return String::Concat(lhs.data(), lhs_size, rhs.data(), rhs_size);
1423 }
1424
1425 inline String operator+(const char* lhs, const String& rhs) {
1426 size_t lhs_size = std::strlen(lhs);
1427 size_t rhs_size = rhs.size();
1428 return String::Concat(lhs, lhs_size, rhs.data(), rhs_size);
1429 }
1430
1431 inline String operator+(const String& lhs, const char* rhs) {
1432 size_t lhs_size = lhs.size();
1433 size_t rhs_size = std::strlen(rhs);
1434 return String::Concat(lhs.data(), lhs_size, rhs, rhs_size);
1435 }
1436
1437 // Overload < operator
1438 inline bool operator<(const String& lhs, const std::string& rhs) { return lhs.compare(rhs) < 0; }
1439
1440 inline bool operator<(const std::string& lhs, const String& rhs) { return rhs.compare(lhs) > 0; }
1441
1442 inline bool operator<(const String& lhs, const String& rhs) { return lhs.compare(rhs) < 0; }
1443
1444 inline bool operator<(const String& lhs, const char* rhs) { return lhs.compare(rhs) < 0; }
1445
1446 inline bool operator<(const char* lhs, const String& rhs) { return rhs.compare(lhs) > 0; }
1447
1448 // Overload > operator
1449 inline bool operator>(const String& lhs, const std::string& rhs) { return lhs.compare(rhs) > 0; }
1450
1451 inline bool operator>(const std::string& lhs, const String& rhs) { return rhs.compare(lhs) < 0; }
1452
1453 inline bool operator>(const String& lhs, const String& rhs) { return lhs.compare(rhs) > 0; }
1454
1455 inline bool operator>(const String& lhs, const char* rhs) { return lhs.compare(rhs) > 0; }
1456
1457 inline bool operator>(const char* lhs, const String& rhs) { return rhs.compare(lhs) < 0; }
1458
1459 // Overload <= operator
1460 inline bool operator<=(const String& lhs, const std::string& rhs) { return lhs.compare(rhs) <= 0; }
1461
1462 inline bool operator<=(const std::string& lhs, const String& rhs) { return rhs.compare(lhs) >= 0; }
1463
1464 inline bool operator<=(const String& lhs, const String& rhs) { return lhs.compare(rhs) <= 0; }
1465
1466 inline bool operator<=(const String& lhs, const char* rhs) { return lhs.compare(rhs) <= 0; }
1467
1468 inline bool operator<=(const char* lhs, const String& rhs) { return rhs.compare(lhs) >= 0; }
1469
1470 // Overload >= operator
1471 inline bool operator>=(const String& lhs, const std::string& rhs) { return lhs.compare(rhs) >= 0; }
1472
1473 inline bool operator>=(const std::string& lhs, const String& rhs) { return rhs.compare(lhs) <= 0; }
1474
1475 inline bool operator>=(const String& lhs, const String& rhs) { return lhs.compare(rhs) >= 0; }
1476
1477 inline bool operator>=(const String& lhs, const char* rhs) { return lhs.compare(rhs) >= 0; }
1478
1479 inline bool operator>=(const char* lhs, const String& rhs) { return rhs.compare(rhs) <= 0; }
1480
1481 // Overload == operator
1482 inline bool operator==(const String& lhs, const std::string& rhs) { return lhs.compare(rhs) == 0; }
1483
1484 inline bool operator==(const std::string& lhs, const String& rhs) { return rhs.compare(lhs) == 0; }
1485
1486 inline bool operator==(const String& lhs, const String& rhs) { return lhs.compare(rhs) == 0; }
1487
1488 inline bool operator==(const String& lhs, const char* rhs) { return lhs.compare(rhs) == 0; }
1489
1490 inline bool operator==(const char* lhs, const String& rhs) { return rhs.compare(lhs) == 0; }
1491
1492 // Overload != operator
1493 inline bool operator!=(const String& lhs, const std::string& rhs) { return lhs.compare(rhs) != 0; }
1494
1495 inline bool operator!=(const std::string& lhs, const String& rhs) { return rhs.compare(lhs) != 0; }
1496
1497 inline bool operator!=(const String& lhs, const String& rhs) { return lhs.compare(rhs) != 0; }
1498
1499 inline bool operator!=(const String& lhs, const char* rhs) { return lhs.compare(rhs) != 0; }
1500
1501 inline bool operator!=(const char* lhs, const String& rhs) { return rhs.compare(lhs) != 0; }
1502
1503 inline std::ostream& operator<<(std::ostream& out, const String& input) {
1504 out.write(input.data(), input.size());
1505 return out;
1506 }
1507
memncmp(const char * lhs,const char * rhs,size_t lhs_count,size_t rhs_count)1508 inline int String::memncmp(const char* lhs, const char* rhs, size_t lhs_count, size_t rhs_count) {
1509 if (lhs == rhs && lhs_count == rhs_count) return 0;
1510
1511 for (size_t i = 0; i < lhs_count && i < rhs_count; ++i) {
1512 if (lhs[i] < rhs[i]) return -1;
1513 if (lhs[i] > rhs[i]) return 1;
1514 }
1515 if (lhs_count < rhs_count) {
1516 return -1;
1517 } else if (lhs_count > rhs_count) {
1518 return 1;
1519 } else {
1520 return 0;
1521 }
1522 }
1523
operator()1524 inline size_t ObjectHash::operator()(const ObjectRef& a) const {
1525 if (const auto* str = a.as<StringObj>()) {
1526 return String::HashBytes(str->data, str->size);
1527 }
1528 return ObjectPtrHash()(a);
1529 }
1530
operator()1531 inline bool ObjectEqual::operator()(const ObjectRef& a, const ObjectRef& b) const {
1532 if (a.same_as(b)) {
1533 return true;
1534 }
1535 if (const auto* str_a = a.as<StringObj>()) {
1536 if (const auto* str_b = b.as<StringObj>()) {
1537 return String::memncmp(str_a->data, str_b->data, str_a->size, str_b->size) == 0;
1538 }
1539 }
1540 return false;
1541 }
1542
/*!
 * \brief Tag type representing "no value" for Optional.
 *
 * Optional<T> has an implicit constructor taking this type, so an instance
 * (tvm::NullOpt) can be used to build or compare against an empty Optional.
 */
struct NullOptType {};
1545
1546 /*!
1547 * \brief Optional container that to represent to a Nullable variant of T.
1548 * \tparam T The original ObjectRef.
1549 *
1550 * \code
1551 *
1552 * Optional<String> opt0 = nullptr;
1553 * Optional<String> opt1 = String("xyz");
1554 * CHECK(opt0 == nullptr);
1555 * CHECK(opt1 == "xyz");
1556 *
1557 * \endcode
1558 */
1559 template <typename T>
1560 class Optional : public ObjectRef {
1561 public:
1562 using ContainerType = typename T::ContainerType;
1563 static_assert(std::is_base_of<ObjectRef, T>::value, "Optional is only defined for ObjectRef.");
1564 // default constructors.
1565 Optional() = default;
1566 Optional(const Optional<T>&) = default;
1567 Optional(Optional<T>&&) = default;
1568 Optional<T>& operator=(const Optional<T>&) = default;
1569 Optional<T>& operator=(Optional<T>&&) = default;
1570 /*!
1571 * \brief Construct from an ObjectPtr
1572 * whose type already matches the ContainerType.
1573 * \param ptr
1574 */
Optional(ObjectPtr<Object> ptr)1575 explicit Optional(ObjectPtr<Object> ptr) : ObjectRef(ptr) {}
1576 /*! \brief Nullopt handling */
Optional(NullOptType)1577 Optional(NullOptType) {} // NOLINT(*)
1578 // nullptr handling.
1579 // disallow implicit conversion as 0 can be implicitly converted to nullptr_t
Optional(std::nullptr_t)1580 explicit Optional(std::nullptr_t) {}
1581 Optional<T>& operator=(std::nullptr_t) {
1582 data_ = nullptr;
1583 return *this;
1584 }
1585 // normal value handling.
Optional(T other)1586 Optional(T other) // NOLINT(*)
1587 : ObjectRef(std::move(other)) {}
1588 Optional<T>& operator=(T other) {
1589 ObjectRef::operator=(std::move(other));
1590 return *this;
1591 }
1592 // delete the int constructor
1593 // since Optional<Integer>(0) is ambiguious
1594 // 0 can be implicitly casted to nullptr_t
1595 explicit Optional(int val) = delete;
1596 Optional<T>& operator=(int val) = delete;
1597 /*!
1598 * \return A not-null container value in the optional.
1599 * \note This function performs not-null checking.
1600 */
value()1601 T value() const {
1602 CHECK(data_ != nullptr);
1603 return T(data_);
1604 }
1605 /*!
1606 * \return The contained value if the Optional is not null
1607 * otherwise return the default_value.
1608 */
value_or(T default_value)1609 T value_or(T default_value) const { return data_ != nullptr ? T(data_) : default_value; }
1610
1611 /*! \return Whether the container is not nullptr.*/
1612 explicit operator bool() const { return *this != nullptr; }
1613 // operator overloadings
1614 bool operator==(std::nullptr_t) const { return data_ == nullptr; }
1615 bool operator!=(std::nullptr_t) const { return data_ != nullptr; }
1616 auto operator==(const Optional<T>& other) const {
1617 // support case where sub-class returns a symbolic ref type.
1618 using RetType = decltype(value() == other.value());
1619 if (same_as(other)) return RetType(true);
1620 if (*this != nullptr && other != nullptr) {
1621 return value() == other.value();
1622 } else {
1623 // one of them is nullptr.
1624 return RetType(false);
1625 }
1626 }
1627 auto operator!=(const Optional<T>& other) const {
1628 // support case where sub-class returns a symbolic ref type.
1629 using RetType = decltype(value() != other.value());
1630 if (same_as(other)) return RetType(false);
1631 if (*this != nullptr && other != nullptr) {
1632 return value() != other.value();
1633 } else {
1634 // one of them is nullptr.
1635 return RetType(true);
1636 }
1637 }
1638 auto operator==(const T& other) const {
1639 using RetType = decltype(value() == other);
1640 if (same_as(other)) return RetType(true);
1641 if (*this != nullptr) return value() == other;
1642 return RetType(false);
1643 }
1644 auto operator!=(const T& other) const { return !(*this == other); }
1645 template <typename U>
1646 auto operator==(const U& other) const {
1647 using RetType = decltype(value() == other);
1648 if (*this == nullptr) return RetType(false);
1649 return value() == other;
1650 }
1651 template <typename U>
1652 auto operator!=(const U& other) const {
1653 using RetType = decltype(value() != other);
1654 if (*this == nullptr) return RetType(true);
1655 return value() != other;
1656 }
1657 static constexpr bool _type_is_nullable = true;
1658 };
1659
/*!
 * \brief An object representing a closure. This object is used by both the
 * Relay VM and interpreter.
 *
 * Declared with TVM_DECLARE_BASE_OBJECT_INFO (not FINAL); presumably concrete
 * closure representations derive from it — confirm against the VM/interpreter.
 */
class ClosureObj : public Object {
 public:
  static constexpr const uint32_t _type_index = TypeIndex::kRuntimeClosure;
  static constexpr const char* _type_key = "runtime.Closure";
  TVM_DECLARE_BASE_OBJECT_INFO(ClosureObj, Object);
};
1670
/*! \brief Reference to a ClosureObj (or an object derived from it). */
class Closure : public ObjectRef {
 public:
  TVM_DEFINE_OBJECT_REF_METHODS(Closure, ObjectRef, ClosureObj);
};
1676
1677 } // namespace runtime
1678
// Expose commonly used container types and the NullOpt sentinel directly in
// the tvm namespace.
using runtime::Optional;
using runtime::String;
/*! \brief Sentinel for an empty Optional, e.g. `Optional<String> x = NullOpt;`. */
constexpr runtime::NullOptType NullOpt{};
1683 } // namespace tvm
1684
1685 namespace std {
1686
1687 template <>
1688 struct hash<::tvm::runtime::String> {
1689 std::size_t operator()(const ::tvm::runtime::String& str) const {
1690 return ::tvm::runtime::String::HashBytes(str.data(), str.size());
1691 }
1692 };
1693 } // namespace std
1694
1695 #endif // TVM_RUNTIME_CONTAINER_H_
1696