1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements.  See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership.  The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License.  You may obtain a copy of the License at
9  *
10  *   http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied.  See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 /*!
20  * \file tvm/ir/attrs.h
21  * \brief Helpers for attribute objects.
22  *
23  *  This module enables declaration of named attributes
24  *  which support default value setup and bound checking.
25  *
26  * \code
27  *   struct MyAttrs : public tvm::AttrsNode<MyAttrs> {
 *     double learning_rate;
 *     int num_hidden;
 *     String name;
 *     // declare attribute fields in header file
 *     TVM_DECLARE_ATTRS(MyAttrs, "attrs.MyAttrs") {
 *       TVM_ATTR_FIELD(num_hidden).set_lower_bound(1);
 *       TVM_ATTR_FIELD(learning_rate).set_default(0.01);
35  *       TVM_ATTR_FIELD(name).set_default("hello");
36  *     }
37  *   };
38  *   // register it in cc file
39  *   TVM_REGISTER_NODE_TYPE(MyAttrs);
40  * \endcode
41  *
42  * \sa AttrsNode, TVM_DECLARE_ATTRS, TVM_ATTR_FIELD
43  */
44 #ifndef TVM_IR_ATTRS_H_
45 #define TVM_IR_ATTRS_H_
46 
47 #include <dmlc/common.h>
48 #include <tvm/ir/expr.h>
49 #include <tvm/node/structural_equal.h>
50 #include <tvm/node/structural_hash.h>
51 #include <tvm/runtime/packed_func.h>
52 
53 #include <functional>
54 #include <string>
55 #include <type_traits>
56 #include <unordered_map>
57 #include <utility>
58 #include <vector>
59 
60 namespace tvm {
/*!
 * \brief Declare the body of an attribute class.
 *
 *  Expands to the `_type_key` constant, the final object type-info
 *  declarations (with ::tvm::BaseAttrsNode as the parent type), and the
 *  signature of the templated member `__VisitAttrs__(FVisit& __fvisit__)`.
 *  The macro must be followed by a `{ ... }` body that lists the fields
 *  with TVM_ATTR_FIELD, as in the example at the top of this file.
 *
 * \param ClassName The name of the class.
 * \param TypeKey The type key to be used by the TVM node system.
 */
#define TVM_DECLARE_ATTRS(ClassName, TypeKey)                    \
  static constexpr const char* _type_key = TypeKey;              \
  TVM_DECLARE_FINAL_OBJECT_INFO(ClassName, ::tvm::BaseAttrsNode) \
  template <typename FVisit>                                     \
  void __VisitAttrs__(FVisit& __fvisit__)  // NOLINT(*)
71 
/*!
 * \brief Declare an attribute field.
 *
 *  Only valid inside the body that follows TVM_DECLARE_ATTRS: it calls the
 *  visitor `__fvisit__` with the stringified field name and the field's
 *  address. The entry returned by the visitor supports chained calls to
 *  describe, set_default, set_lower_bound and set_upper_bound.
 *
 * \param FieldName The field name.
 */
#define TVM_ATTR_FIELD(FieldName) __fvisit__(#FieldName, &FieldName)
77 
78 /*!
79  * \brief Create a NodeRef type that represents null.
80  * \tparam TNodeRef the type to be created.
81  * \return A instance that will represent None.
82  */
83 template <typename TObjectRef>
NullValue()84 inline TObjectRef NullValue() {
85   static_assert(TObjectRef::_type_is_nullable, "Can only get NullValue for nullable types");
86   return TObjectRef(ObjectPtr<Object>(nullptr));
87 }
88 
89 template <>
90 inline DataType NullValue<DataType>() {
91   return DataType(DataType::kHandle, 0, 0);
92 }
93 
94 /*! \brief Error thrown during attribute checking. */
95 struct AttrError : public dmlc::Error {
96   /*!
97    * \brief constructor
98    * \param msg error message
99    */
AttrErrorAttrError100   explicit AttrError(std::string msg) : dmlc::Error("AttributeError:" + msg) {}
101 };
102 
/*!
 * \brief Information about an attribute field, in string representation.
 *
 *  Instances are produced when listing the fields of an attribute class
 *  (see AttrsNode::ListFieldInfo) and are printed by
 *  BaseAttrsNode::PrintDocString.
 */
class AttrFieldInfoNode : public Object {
 public:
  /*! \brief name of the field */
  String name;
  /*!
   * \brief type docstring information in str; when a default is declared
   *  it is followed by ", default=<value>".
   */
  String type_info;
  /*! \brief detailed description of the field */
  String description;

  // Expose the three string fields to the reflection visitor.
  void VisitAttrs(AttrVisitor* v) {
    v->Visit("name", &name);
    v->Visit("type_info", &type_info);
    v->Visit("description", &description);
  }

  static constexpr const char* _type_key = "AttrFieldInfo";
  // This node does not define SEqualReduce / SHashReduce methods.
  static constexpr bool _type_has_method_sequal_reduce = false;
  static constexpr bool _type_has_method_shash_reduce = false;
  TVM_DECLARE_FINAL_OBJECT_INFO(AttrFieldInfoNode, Object);
};
126 
/*!
 * \brief Managed reference to AttrFieldInfoNode.
 * \sa AttrFieldInfoNode
 */
class AttrFieldInfo : public ObjectRef {
 public:
  TVM_DEFINE_OBJECT_REF_METHODS(AttrFieldInfo, ObjectRef, AttrFieldInfoNode);
};
132 
/*!
 * \brief Base class of all attribute classes.
 * \note Do not subclass BaseAttrsNode directly,
 *       subclass AttrsNode instead (DictAttrsNode, which is backed by
 *       a map rather than declared fields, is the exception).
 * \sa AttrsNode
 */
class BaseAttrsNode : public Object {
 public:
  using TVMArgs = runtime::TVMArgs;
  using TVMRetValue = runtime::TVMRetValue;
  /*! \brief virtual destructor */
  virtual ~BaseAttrsNode() {}
  // Reflection visit function; the base implementation visits nothing.
  virtual void VisitAttrs(AttrVisitor* v) {}
  /*!
   * \brief Initialize the attributes by a sequence of arguments.
   *  The arguments are packed and forwarded to InitByPackedArgs with
   *  unknown keys disallowed.
   * \param args The positional arguments in the form
   *        [key0, value0, key1, value1, ..., key_n, value_n]
   */
  template <typename... Args>
  inline void InitBySeq(Args&&... args);
  /*!
   * \brief Print a readable docstring to ostream: one "name : type_info"
   *  line per field, followed by the indented description if any.
   * \param os the stream to print the docstring to.
   */
  inline void PrintDocString(std::ostream& os) const;  // NOLINT(*)
  /*!
   * \brief Visit attributes that do not equal the default value.
   *
   * \note This is useful to extract fields for concise printing.
   * \param v The visitor
   */
  TVM_DLL virtual void VisitNonDefaultAttrs(AttrVisitor* v) = 0;
  /*!
   * \brief Get the field information
   * \return The fields in the Attrs.
   */
  TVM_DLL virtual Array<AttrFieldInfo> ListFieldInfo() const = 0;
  /*!
   * \brief Initialize the attributes by arguments.
   * \param kwargs The key value pairs for initialization.
   *        [key0, value0, key1, value1, ..., key_n, value_n]
   * \param allow_unknown Whether to allow additional unknown fields.
   * \note This function throws when a required field is not present.
   */
  TVM_DLL virtual void InitByPackedArgs(const TVMArgs& kwargs, bool allow_unknown = false) = 0;

  static constexpr const bool _type_has_method_sequal_reduce = true;
  static constexpr const bool _type_has_method_shash_reduce = true;
  static constexpr const char* _type_key = "Attrs";
  TVM_DECLARE_BASE_OBJECT_INFO(BaseAttrsNode, Object);
};
185 
/*!
 * \brief Managed reference to BaseAttrsNode; the generic handle type for
 *  any attribute object.
 * \sa AttrsNode, BaseAttrsNode
 */
class Attrs : public ObjectRef {
 public:
  TVM_DEFINE_OBJECT_REF_METHODS(Attrs, ObjectRef, BaseAttrsNode);
};
194 
/*!
 * \brief Specialized attribute type that is backed by a map.
 *  The DictAttrsNode implements the Attrs behavior,
 *  its fields are directly accessible via object.field_name
 *  like other normal nodes.
 */
class DictAttrsNode : public BaseAttrsNode {
 public:
  /*! \brief internal attrs map */
  Map<String, ObjectRef> dict;

  // Structural equality is equality of the underlying maps.
  bool SEqualReduce(const DictAttrsNode* other, SEqualReducer equal) const {
    return equal(dict, other->dict);
  }

  // Structural hash is the hash of the underlying map.
  void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(dict); }

  // BaseAttrsNode interface; these are declared here and defined outside
  // this header.
  void VisitAttrs(AttrVisitor* v) final;
  void VisitNonDefaultAttrs(AttrVisitor* v) final;
  void InitByPackedArgs(const runtime::TVMArgs& args, bool allow_unknown) final;
  Array<AttrFieldInfo> ListFieldInfo() const final;
  // type info
  static constexpr const char* _type_key = "DictAttrs";
  TVM_DECLARE_FINAL_OBJECT_INFO(DictAttrsNode, BaseAttrsNode);
};
221 
/*!
 * \brief Managed reference to DictAttrsNode
 * \sa DictAttrsNode.
 */
class DictAttrs : public Attrs {
 public:
  /*!
   * \brief Construct an Attrs backed by DictAttrsNode.
   * \param dict The attributes.
   * \return The dict attributes.
   */
  TVM_DLL explicit DictAttrs(Map<String, ObjectRef> dict);

  TVM_DEFINE_OBJECT_REF_METHODS(DictAttrs, Attrs, DictAttrsNode);
  TVM_DEFINE_OBJECT_REF_COW_METHOD(DictAttrsNode);
};
238 
239 /*!
240  * \brief Create an Attr object with all default values.
241  * \tparam TAttrNode the type to be created.
242  * \return A instance that will represent None.
243  */
244 template <typename TAttrs>
AttrsWithDefaultValues()245 inline TAttrs AttrsWithDefaultValues() {
246   static_assert(std::is_base_of<Attrs, TAttrs>::value, "Can only take attr nodes");
247   auto n = make_object<typename TAttrs::ContainerType>();
248   n->InitByPackedArgs(runtime::TVMArgs(nullptr, nullptr, 0), false);
249   return TAttrs(n);
250 }
251 
252 // Namespace containing detail implementations
253 namespace detail {
254 using runtime::TVMArgValue;
255 
// No-op entry returned by visitors that ignore the chained
// describe / set_default / set_lower_bound / set_upper_bound calls.
// Every method discards its argument and returns *this.
struct AttrNopEntry {
  using TSelf = AttrNopEntry;

  TSelf& describe(const char* /* str */) { return *this; }

  template <typename T>
  TSelf& set_default(const T& /* value */) {
    return *this;
  }

  template <typename T>
  TSelf& set_lower_bound(const T& /* begin */) {
    return *this;
  }

  template <typename T>
  TSelf& set_upper_bound(const T& /* end */) {
    return *this;
  }
};
274 
275 // Wrapper for normal visitor.
276 class AttrNormalVisitor {
277  public:
AttrNormalVisitor(AttrVisitor * visitor)278   explicit AttrNormalVisitor(AttrVisitor* visitor) : visitor_(visitor) {}
279   template <typename T>
operator()280   AttrNopEntry operator()(const char* key, T* value) {
281     visitor_->Visit(key, value);
282     return AttrNopEntry();
283   }
284 
285  private:
286   AttrVisitor* visitor_;
287 };
288 
289 class AttrsSEqualVisitor {
290  public:
291   bool result_{true};
292   // constructor
AttrsSEqualVisitor(const Object * lhs,const Object * rhs,const SEqualReducer & equal)293   AttrsSEqualVisitor(const Object* lhs, const Object* rhs, const SEqualReducer& equal)
294       : lhs_(lhs), rhs_(rhs), equal_(equal) {}
295   template <typename T>
operator()296   AttrNopEntry operator()(const char* key, T* lhs_value) {
297     if (!result_) return AttrNopEntry();
298     const T* rhs_value = reinterpret_cast<const T*>(
299         reinterpret_cast<const char*>(rhs_) +
300         (reinterpret_cast<const char*>(lhs_value) - reinterpret_cast<const char*>(lhs_)));
301     if (!equal_(*lhs_value, *rhs_value)) {
302       result_ = false;
303     }
304     return AttrNopEntry();
305   }
306 
307  private:
308   const Object* lhs_;
309   const Object* rhs_;
310   const SEqualReducer& equal_;
311 };
312 
313 class AttrsSHashVisitor {
314  public:
AttrsSHashVisitor(const SHashReducer & hash_reducer)315   explicit AttrsSHashVisitor(const SHashReducer& hash_reducer) : hash_reducer_(hash_reducer) {}
316 
317   template <typename T>
operator()318   AttrNopEntry operator()(const char* key, T* value) {
319     hash_reducer_(*value);
320     return AttrNopEntry();
321   }
322 
323  private:
324   const SHashReducer& hash_reducer_;
325 };
326 
327 // helper entry that does initialization, set default.
328 template <typename T>
329 struct AttrInitEntry {
330   // The attributes
331   using TSelf = AttrInitEntry<T>;
332   // The type key
333   const char* type_key_;
334   // field name
335   const char* key_;
336   // internal value.
337   T* value_;
338   // whether the value is missing.
339   bool value_missing_{true};
340 
341   AttrInitEntry() = default;
342 
AttrInitEntryAttrInitEntry343   AttrInitEntry(AttrInitEntry&& other) {
344     type_key_ = other.type_key_;
345     key_ = other.key_;
346     value_ = other.value_;
347     value_missing_ = other.value_missing_;
348     // avoid unexpected throw
349     other.value_missing_ = false;
350   }
351 
352   // If the value is still missing in destruction time throw an error.
~AttrInitEntryAttrInitEntry353   ~AttrInitEntry() DMLC_THROW_EXCEPTION {
354     if (value_missing_) {
355       std::ostringstream os;
356       os << type_key_ << ": Cannot find required field \'" << key_ << "\' during initialization."
357          << "If the key is defined check that its type matches the declared type.";
358       throw AttrError(os.str());
359     }
360   }
361   // override fields.
362   // This function sets the lower bound of the attribute
set_lower_boundAttrInitEntry363   TSelf& set_lower_bound(const T& begin) {
364     if (this->value_missing_) return *this;
365     const T& val = *value_;
366     if (begin > val) {
367       std::ostringstream os;
368       os << type_key_ << "." << key_ << ": "
369          << "value " << val << " is smaller than the lower bound " << begin;
370       throw AttrError(os.str());
371     }
372     return *this;
373   }
374   // This function sets the upper bound of the attribute
set_upper_boundAttrInitEntry375   TSelf& set_upper_bound(const T& end) {
376     if (this->value_missing_) return *this;
377     const T& val = *value_;
378     if (val > end) {
379       std::ostringstream os;
380       os << type_key_ << "." << key_ << ": "
381          << "value " << val << " is bigger than the upper bound " << end;
382       throw AttrError(os.str());
383     }
384     return *this;
385   }
386   // set default when
set_defaultAttrInitEntry387   TSelf& set_default(const T& value) {
388     if (!value_missing_) return *this;
389     *value_ = value;
390     value_missing_ = false;
391     return *this;
392   }
describeAttrInitEntry393   TSelf& describe(DMLC_ATTRIBUTE_UNUSED const char* str) { return *this; }
394 };
395 
396 // Template function to allow smart conversion
397 // from Expr types into the constants.
398 template <typename T>
SetValue(T * ptr,const TVMArgValue & val)399 inline void SetValue(T* ptr, const TVMArgValue& val) {
400   *ptr = val.operator T();
401 }
402 
403 template <typename T>
SetIntValue(T * ptr,const TVMArgValue & val)404 inline void SetIntValue(T* ptr, const TVMArgValue& val) {
405   if (val.type_code() == kDLInt) {
406     *ptr = static_cast<T>(val.value().v_int64);
407   } else {
408     IntImm expr = val;
409     *ptr = static_cast<T>(expr->value);
410   }
411 }
412 
413 template <>
414 inline void SetValue<std::string>(std::string* ptr, const TVMArgValue& val) {
415   if (String::CanConvertFrom(val)) {
416     *ptr = val.operator std::string();
417   } else {
418     LOG(FATAL) << "Expect str";
419   }
420 }
421 
422 template <>
423 inline void SetValue<double>(double* ptr, const TVMArgValue& val) {
424   if (val.type_code() == kDLFloat || val.type_code() == kDLInt) {
425     *ptr = val.operator double();
426   } else {
427     ObjectRef expr = val;
428     CHECK(expr.defined());
429     if (const IntImmNode* op = expr.as<IntImmNode>()) {
430       *ptr = static_cast<double>(op->value);
431     } else if (const FloatImmNode* op = expr.as<FloatImmNode>()) {
432       *ptr = static_cast<double>(op->value);
433     } else {
434       LOG(FATAL) << "Expect float value, but get " << expr->GetTypeKey();
435     }
436   }
437 }
438 template <>
439 inline void SetValue<int>(int* ptr, const TVMArgValue& val) {
440   SetIntValue(ptr, val);
441 }
442 template <>
443 inline void SetValue<int64_t>(int64_t* ptr, const TVMArgValue& val) {
444   SetIntValue(ptr, val);
445 }
446 template <>
447 inline void SetValue<uint64_t>(uint64_t* ptr, const TVMArgValue& val) {
448   SetIntValue(ptr, val);
449 }
450 template <>
451 inline void SetValue<bool>(bool* ptr, const TVMArgValue& val) {
452   SetIntValue(ptr, val);
453 }
454 
455 // Visitor for value initialization
456 template <typename FFind>
457 class AttrInitVisitor {
458  public:
459   // Counter of number of matched attributes during visit.
460   // This is used to decide if there is additional unmatched attributes.
461   size_t hit_count_{0};
462   // constructor
AttrInitVisitor(const char * type_key,FFind ffind)463   AttrInitVisitor(const char* type_key, FFind ffind) : type_key_(type_key), ffind_(ffind) {}
464 
465   template <typename T>
operator()466   AttrInitEntry<T> operator()(const char* key, T* value) {
467     TVMArgValue val;
468     AttrInitEntry<T> opt;
469     opt.type_key_ = type_key_;
470     opt.key_ = key;
471     opt.value_ = value;
472     if (ffind_(key, &val)) {
473       SetValue(value, val);
474       opt.value_missing_ = false;
475       ++hit_count_;
476     } else {
477       opt.value_missing_ = true;
478     }
479 #if defined(__GNUC__)
480 #pragma GCC diagnostic ignored "-Wpragmas"
481 #pragma GCC diagnostic ignored "-Wpessimizing-move"
482 #endif
483     return std::move(opt);
484   }
485 
486  private:
487   // the type key
488   const char* type_key_;
489   FFind ffind_;
490 };
491 
492 template <typename FFind>
CreateInitVisitor(const char * type_key,FFind ffind)493 inline AttrInitVisitor<FFind> CreateInitVisitor(const char* type_key, FFind ffind) {
494   return AttrInitVisitor<FFind>(type_key, ffind);
495 }
496 
/*!
 * \brief Helper struct to get the type name known to tvm.
 *
 *  The primary template reads the `_type_key` of T's ContainerType, so it
 *  only applies to object reference types; the specializations below cover
 *  the POD, handle and string field types. The name is used as the type
 *  string when listing attribute fields (see AttrDocVisitor).
 *
 * \tparam T the type we are interested in.
 */
template <typename T>
struct TypeName {
  static constexpr const char* value = T::ContainerType::_type_key;
};

template <>
struct TypeName<int> {
  static constexpr const char* value = "int";
};

template <>
struct TypeName<int64_t> {
  static constexpr const char* value = "int64";
};

// NOTE(review): spelled "uint64_t" while the int64_t entry is "int64"; the
// string is user-visible in field docstrings, so it is left unchanged here.
template <>
struct TypeName<uint64_t> {
  static constexpr const char* value = "uint64_t";
};

template <>
struct TypeName<DataType> {
  static constexpr const char* value = "DataType";
};

template <>
struct TypeName<std::string> {
  static constexpr const char* value = "str";
};

template <>
struct TypeName<bool> {
  static constexpr const char* value = "bool";
};

template <>
struct TypeName<void*> {
  static constexpr const char* value = "handle";
};

template <>
struct TypeName<double> {
  static constexpr const char* value = "double";
};
545 
546 class AttrDocEntry {
547  public:
548   using TSelf = AttrDocEntry;
549 
550   explicit AttrDocEntry(ObjectPtr<AttrFieldInfoNode> info) : info_(info) {}
551   TSelf& describe(const char* str) {
552     info_->description = str;
553     return *this;
554   }
555   template <typename T>
556   TSelf& set_default(const T& value) {
557     std::ostringstream os;
558     os << info_->type_info << ", default=" << value;
559     info_->type_info = os.str();
560     return *this;
561   }
562   template <typename T>
563   TSelf& set_lower_bound(DMLC_ATTRIBUTE_UNUSED T begin) {
564     return *this;
565   }
566   template <typename T>
567   TSelf& set_upper_bound(DMLC_ATTRIBUTE_UNUSED T end) {
568     return *this;
569   }
570 
571  private:
572   ObjectPtr<AttrFieldInfoNode> info_;
573 };
574 
575 class AttrDocVisitor {
576  public:
577   template <typename T>
578   AttrDocEntry operator()(const char* key, T* v) {
579     ObjectPtr<AttrFieldInfoNode> info = make_object<AttrFieldInfoNode>();
580     info->name = key;
581     info->type_info = TypeName<T>::value;
582     fields_.push_back(AttrFieldInfo(info));
583     return AttrDocEntry(info);
584   }
585 
586   Array<AttrFieldInfo> fields_;
587 };
588 
589 class AttrExistVisitor {
590  public:
591   std::string key_;
592   bool exist_{false};
593 
594   template <typename T>
595   AttrNopEntry operator()(const char* key, T* v) {
596     if (exist_) return AttrNopEntry();
597     if (key == key_) exist_ = true;
598     return AttrNopEntry();
599   }
600 };
601 
602 template <typename T>
603 struct AttrTriggerNonDefaultEntry {
604   using TSelf = AttrTriggerNonDefaultEntry<T>;
605   // constructor
606   AttrTriggerNonDefaultEntry(AttrVisitor* visitor, const char* key, T* data)
607       : visitor_(visitor), key_(key), data_(data) {}
608 
609   ~AttrTriggerNonDefaultEntry() DMLC_THROW_EXCEPTION {
610     if (trigger_) {
611       visitor_->Visit(key_, data_);
612     }
613   }
614   TSelf& describe(DMLC_ATTRIBUTE_UNUSED const char* str) { return *this; }
615   TSelf& set_default(const T& value) {
616     if (tvm::StructuralEqual()(value, *data_)) {
617       trigger_ = false;
618     }
619     return *this;
620   }
621   TSelf& set_lower_bound(DMLC_ATTRIBUTE_UNUSED const T& begin) { return *this; }
622   TSelf& set_upper_bound(DMLC_ATTRIBUTE_UNUSED const T& end) { return *this; }
623 
624  private:
625   AttrVisitor* visitor_;
626   const char* key_;
627   T* data_;
628   bool trigger_{true};
629 };
630 
631 class AttrNonDefaultVisitor {
632  public:
633   explicit AttrNonDefaultVisitor(AttrVisitor* visitor) : visitor_(visitor) {}
634   template <typename T>
635   AttrTriggerNonDefaultEntry<T> operator()(const char* key, T* value) {
636     return AttrTriggerNonDefaultEntry<T>(visitor_, key, value);
637   }
638 
639  private:
640   AttrVisitor* visitor_;
641 };
642 }  // namespace detail
643 
644 /*!
645  * \brief The base class of the all the
646  *  Use "curiously recurring template pattern".
647  *
648  * \tparam DerivedType The final attribute type.
649  */
650 template <typename DerivedType>
651 class AttrsNode : public BaseAttrsNode {
652  public:
653   void VisitAttrs(AttrVisitor* v) {
654     ::tvm::detail::AttrNormalVisitor vis(v);
655     self()->__VisitAttrs__(vis);
656   }
657 
658   void VisitNonDefaultAttrs(AttrVisitor* v) {
659     ::tvm::detail::AttrNonDefaultVisitor vis(v);
660     self()->__VisitAttrs__(vis);
661   }
662 
663   void InitByPackedArgs(const runtime::TVMArgs& args, bool allow_unknown) final {
664     CHECK_EQ(args.size() % 2, 0);
665     const int kLinearSearchBound = 16;
666     int hit_count = 0;
667     // applies two stratgies to lookup
668     if (args.size() < kLinearSearchBound) {
669       // linear search.
670       auto ffind = [&args](const char* key, runtime::TVMArgValue* val) {
671         for (int i = 0; i < args.size(); i += 2) {
672           CHECK_EQ(args.type_codes[i], kTVMStr);
673           if (!std::strcmp(key, args.values[i].v_str)) {
674             *val = args[i + 1];
675             return true;
676           }
677         }
678         return false;
679       };
680       auto vis = ::tvm::detail::CreateInitVisitor(DerivedType::_type_key, ffind);
681       self()->__VisitAttrs__(vis);
682       hit_count = vis.hit_count_;
683     } else {
684       // construct a map then do lookup.
685       std::unordered_map<std::string, runtime::TVMArgValue> kwargs;
686       for (int i = 0; i < args.size(); i += 2) {
687         CHECK_EQ(args.type_codes[i], kTVMStr);
688         kwargs[args[i].operator std::string()] = args[i + 1];
689       }
690       auto ffind = [&kwargs](const char* key, runtime::TVMArgValue* val) {
691         auto it = kwargs.find(key);
692         if (it != kwargs.end()) {
693           *val = it->second;
694           return true;
695         }
696         return false;
697       };
698       auto vis = ::tvm::detail::CreateInitVisitor(DerivedType::_type_key, ffind);
699       self()->__VisitAttrs__(vis);
700       hit_count = vis.hit_count_;
701     }
702     // error handling, slow path
703     if (hit_count * 2 != args.size() && !allow_unknown) {
704       for (int i = 0; i < args.size(); i += 2) {
705         ::tvm::detail::AttrExistVisitor visitor;
706         visitor.key_ = args[i].operator std::string();
707         self()->__VisitAttrs__(visitor);
708         if (!visitor.exist_) {
709           std::ostringstream os;
710           os << DerivedType::_type_key << ": does not have field \'" << visitor.key_
711              << "\', Possible fields:\n";
712           os << "----------------\n";
713           this->PrintDocString(os);
714           throw AttrError(os.str());
715         }
716       }
717     }
718   }
719 
720   bool SEqualReduce(const DerivedType* other, SEqualReducer equal) const {
721     DerivedType* pself = self();
722     ::tvm::detail::AttrsSEqualVisitor visitor(pself, other, equal);
723     self()->__VisitAttrs__(visitor);
724     return visitor.result_;
725   }
726 
727   void SHashReduce(SHashReducer hash_reducer) const {
728     ::tvm::detail::AttrsSHashVisitor visitor(hash_reducer);
729     self()->__VisitAttrs__(visitor);
730   }
731 
732   Array<AttrFieldInfo> ListFieldInfo() const final {
733     ::tvm::detail::AttrDocVisitor visitor;
734     self()->__VisitAttrs__(visitor);
735     return visitor.fields_;
736   }
737 
738  private:
739   DerivedType* self() const {
740     return const_cast<DerivedType*>(static_cast<const DerivedType*>(this));
741   }
742 };
743 
744 template <typename... Args>
745 inline void BaseAttrsNode::InitBySeq(Args&&... args) {
746   runtime::PackedFunc pf(
747       [this](const TVMArgs& args, TVMRetValue* rv) { this->InitByPackedArgs(args); });
748   pf(std::forward<Args>(args)...);
749 }
750 
751 inline void BaseAttrsNode::PrintDocString(std::ostream& os) const {  // NOLINT(*)
752   Array<AttrFieldInfo> entry = this->ListFieldInfo();
753   for (AttrFieldInfo info : entry) {
754     os << info->name << " : " << info->type_info << '\n';
755     if (info->description.length() != 0) {
756       os << "    " << info->description << '\n';
757     }
758   }
759 }
760 
761 }  // namespace tvm
762 #endif  // TVM_IR_ATTRS_H_
763