1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19 /*!
20 * \file tvm/ir/attrs.h
21 * \brief Helpers for attribute objects.
22 *
23 * This module enables declaration of named attributes
24 * which support default value setup and bound checking.
25 *
26 * \code
27 * struct MyAttrs : public tvm::AttrsNode<MyAttrs> {
28 * float learning_rate;
29 * int num_hidden;
30 * String name;
31 * // declare attribute fields in header file
32 * TVM_DECLARE_ATTRS(MyAttrs, "attrs.MyAttrs") {
33 * TVM_ATTR_FIELD(num_hidden).set_lower_bound(1);
34 * TVM_ATTR_FIELD(learning_rate).set_default(0.01f);
35 * TVM_ATTR_FIELD(name).set_default("hello");
36 * }
37 * };
38 * // register it in cc file
39 * TVM_REGISTER_NODE_TYPE(MyAttrs);
40 * \endcode
41 *
42 * \sa AttrsNode, TVM_DECLARE_ATTRS, TVM_ATTR_FIELD
43 */
44 #ifndef TVM_IR_ATTRS_H_
45 #define TVM_IR_ATTRS_H_
46
#include <dmlc/common.h>
#include <tvm/ir/expr.h>
#include <tvm/node/structural_equal.h>
#include <tvm/node/structural_hash.h>
#include <tvm/runtime/packed_func.h>

#include <cstring>
#include <functional>
#include <sstream>
#include <string>
#include <type_traits>
#include <unordered_map>
#include <utility>
#include <vector>
59
60 namespace tvm {
/*!
 * \brief Declare an attribute function.
 *
 * Used inside the body of an attrs node class. It expands to:
 *  - a static `_type_key` member initialized with TypeKey,
 *  - TVM_DECLARE_FINAL_OBJECT_INFO(ClassName, ::tvm::BaseAttrsNode),
 *  - the head of the member template `__VisitAttrs__(FVisit& __fvisit__)`.
 * The expansion ends at the parameter list, so the use site must supply the
 * braces holding the TVM_ATTR_FIELD(...) declarations (see the example in the
 * file header).
 *
 * \param ClassName The name of the class.
 * \param TypeKey The type key to be used by the TVM node system.
 */
#define TVM_DECLARE_ATTRS(ClassName, TypeKey)                    \
  static constexpr const char* _type_key = TypeKey;              \
  TVM_DECLARE_FINAL_OBJECT_INFO(ClassName, ::tvm::BaseAttrsNode) \
  template <typename FVisit>                                     \
  void __VisitAttrs__(FVisit& __fvisit__)  // NOLINT(*)
71
/*!
 * \brief Declare an attribute field.
 *
 * Only valid inside the body introduced by TVM_DECLARE_ATTRS. It calls the
 * visitor `__fvisit__` with the stringified field name and a pointer to the
 * field. The visitor's return value is the entry on which chained calls such
 * as .set_default(), .set_lower_bound() or .describe() operate.
 *
 * \param FieldName The field name.
 */
#define TVM_ATTR_FIELD(FieldName) __fvisit__(#FieldName, &FieldName)
77
78 /*!
79 * \brief Create a NodeRef type that represents null.
80 * \tparam TNodeRef the type to be created.
81 * \return A instance that will represent None.
82 */
83 template <typename TObjectRef>
NullValue()84 inline TObjectRef NullValue() {
85 static_assert(TObjectRef::_type_is_nullable, "Can only get NullValue for nullable types");
86 return TObjectRef(ObjectPtr<Object>(nullptr));
87 }
88
89 template <>
90 inline DataType NullValue<DataType>() {
91 return DataType(DataType::kHandle, 0, 0);
92 }
93
94 /*! \brief Error thrown during attribute checking. */
95 struct AttrError : public dmlc::Error {
96 /*!
97 * \brief constructor
98 * \param msg error message
99 */
AttrErrorAttrError100 explicit AttrError(std::string msg) : dmlc::Error("AttributeError:" + msg) {}
101 };
102
/*!
 * \brief Information about attribute fields in string representations.
 *
 * Instances are filled in by AttrDocVisitor / AttrDocEntry when
 * ListFieldInfo() runs, and are read by PrintDocString().
 */
class AttrFieldInfoNode : public Object {
 public:
  /*! \brief name of the field */
  String name;
  /*! \brief type name of the field, optionally followed by ", default=<value>" */
  String type_info;
  /*! \brief description of the field, as given through describe() */
  String description;

  // Expose the three string members to reflection, in declaration order.
  void VisitAttrs(AttrVisitor* v) {
    v->Visit("name", &name);
    v->Visit("type_info", &type_info);
    v->Visit("description", &description);
  }

  static constexpr const char* _type_key = "AttrFieldInfo";
  // This class defines no SEqualReduce / SHashReduce methods.
  static constexpr bool _type_has_method_sequal_reduce = false;
  static constexpr bool _type_has_method_shash_reduce = false;
  TVM_DECLARE_FINAL_OBJECT_INFO(AttrFieldInfoNode, Object);
};
126
/*!
 * \brief Managed reference to AttrFieldInfoNode.
 * \sa AttrFieldInfoNode
 */
class AttrFieldInfo : public ObjectRef {
 public:
  TVM_DEFINE_OBJECT_REF_METHODS(AttrFieldInfo, ObjectRef, AttrFieldInfoNode);
};
132
133 /*!
134 * \brief Base class of all attribute class
135 * \note Do not subclass AttrBaseNode directly,
136 * subclass AttrsNode instead.
137 * \sa AttrsNode
138 */
139 class BaseAttrsNode : public Object {
140 public:
141 using TVMArgs = runtime::TVMArgs;
142 using TVMRetValue = runtime::TVMRetValue;
143 /*! \brief virtual destructor */
~BaseAttrsNode()144 virtual ~BaseAttrsNode() {}
145 // visit function
VisitAttrs(AttrVisitor * v)146 virtual void VisitAttrs(AttrVisitor* v) {}
147 /*!
148 * \brief Initialize the attributes by sequence of arguments
149 * \param args The postional arguments in the form
150 * [key0, value0, key1, value1, ..., key_n, value_n]
151 */
152 template <typename... Args>
153 inline void InitBySeq(Args&&... args);
154 /*!
155 * \brief Print readible docstring to ostream, add newline.
156 * \param os the stream to print the docstring to.
157 */
158 inline void PrintDocString(std::ostream& os) const; // NOLINT(*)
159 /*!
160 * \brief Visit attributes that do not equal the default value.
161 *
162 * \note This is useful to extract fields for concise printing.
163 * \param v The visitor
164 */
165 TVM_DLL virtual void VisitNonDefaultAttrs(AttrVisitor* v) = 0;
166 /*!
167 * \brief Get the field information
168 * \return The fields in the Attrs.
169 */
170 TVM_DLL virtual Array<AttrFieldInfo> ListFieldInfo() const = 0;
171 /*!
172 * \brief Initialize the attributes by arguments.
173 * \param kwargs The key value pairs for initialization.
174 * [key0, value0, key1, value1, ..., key_n, value_n]
175 * \param allow_unknown Whether allow additional unknown fields.
176 * \note This function throws when the required field is not present.
177 */
178 TVM_DLL virtual void InitByPackedArgs(const TVMArgs& kwargs, bool allow_unknown = false) = 0;
179
180 static constexpr const bool _type_has_method_sequal_reduce = true;
181 static constexpr const bool _type_has_method_shash_reduce = true;
182 static constexpr const char* _type_key = "Attrs";
183 TVM_DECLARE_BASE_OBJECT_INFO(BaseAttrsNode, Object);
184 };
185
/*!
 * \brief Managed reference to BaseAttrsNode.
 *
 * Common reference base for all attribute objects; AttrsWithDefaultValues()
 * statically requires its TAttrs argument to derive from this class.
 * \sa AttrsNode, BaseAttrsNode
 */
class Attrs : public ObjectRef {
 public:
  TVM_DEFINE_OBJECT_REF_METHODS(Attrs, ObjectRef, BaseAttrsNode);
};
194
/*!
 * \brief Specialized attribute type that is backed by a map.
 * The DictAttrsNode implements the Attrs behavior,
 * its fields are directly accessible via object.field_name
 * like other normal nodes.
 */
class DictAttrsNode : public BaseAttrsNode {
 public:
  /*! \brief internal attrs map */
  Map<String, ObjectRef> dict;

  // Structural equality is decided solely by the contents of dict.
  bool SEqualReduce(const DictAttrsNode* other, SEqualReducer equal) const {
    return equal(dict, other->dict);
  }

  // Structural hash is computed solely from the contents of dict.
  void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(dict); }

  // implementations (only declared here; the definitions are not in this header)
  void VisitAttrs(AttrVisitor* v) final;
  void VisitNonDefaultAttrs(AttrVisitor* v) final;
  void InitByPackedArgs(const runtime::TVMArgs& args, bool allow_unknown) final;
  Array<AttrFieldInfo> ListFieldInfo() const final;
  // type info
  static constexpr const char* _type_key = "DictAttrs";
  TVM_DECLARE_FINAL_OBJECT_INFO(DictAttrsNode, BaseAttrsNode);
};
221
/*!
 * \brief Managed reference to DictAttrsNode
 * \sa DictAttrsNode.
 */
class DictAttrs : public Attrs {
 public:
  /*!
   * \brief Construct an Attrs backed by DictAttrsNode.
   * \param dict The attributes.
   */
  TVM_DLL explicit DictAttrs(Map<String, ObjectRef> dict);

  TVM_DEFINE_OBJECT_REF_METHODS(DictAttrs, Attrs, DictAttrsNode);
  // Presumably adds a copy-on-write accessor to DictAttrsNode (per the macro name).
  TVM_DEFINE_OBJECT_REF_COW_METHOD(DictAttrsNode);
};
238
239 /*!
240 * \brief Create an Attr object with all default values.
241 * \tparam TAttrNode the type to be created.
242 * \return A instance that will represent None.
243 */
244 template <typename TAttrs>
AttrsWithDefaultValues()245 inline TAttrs AttrsWithDefaultValues() {
246 static_assert(std::is_base_of<Attrs, TAttrs>::value, "Can only take attr nodes");
247 auto n = make_object<typename TAttrs::ContainerType>();
248 n->InitByPackedArgs(runtime::TVMArgs(nullptr, nullptr, 0), false);
249 return TAttrs(n);
250 }
251
252 // Namespace containing detail implementations
253 namespace detail {
254 using runtime::TVMArgValue;
255
/*!
 * \brief Entry returned by visitors that ignore field metadata.
 *
 * Every chained call (set_lower_bound / set_upper_bound / set_default /
 * describe) is accepted, ignored, and returns this same entry.
 */
struct AttrNopEntry {
  using TSelf = AttrNopEntry;

  template <typename T>
  TSelf& set_lower_bound(const T& /*begin*/) {
    return *this;
  }
  template <typename T>
  TSelf& set_upper_bound(const T& /*end*/) {
    return *this;
  }
  template <typename T>
  TSelf& set_default(const T& /*value*/) {
    return *this;
  }
  TSelf& describe(const char* /*str*/) { return *this; }
};
274
275 // Wrapper for normal visitor.
276 class AttrNormalVisitor {
277 public:
AttrNormalVisitor(AttrVisitor * visitor)278 explicit AttrNormalVisitor(AttrVisitor* visitor) : visitor_(visitor) {}
279 template <typename T>
operator()280 AttrNopEntry operator()(const char* key, T* value) {
281 visitor_->Visit(key, value);
282 return AttrNopEntry();
283 }
284
285 private:
286 AttrVisitor* visitor_;
287 };
288
289 class AttrsSEqualVisitor {
290 public:
291 bool result_{true};
292 // constructor
AttrsSEqualVisitor(const Object * lhs,const Object * rhs,const SEqualReducer & equal)293 AttrsSEqualVisitor(const Object* lhs, const Object* rhs, const SEqualReducer& equal)
294 : lhs_(lhs), rhs_(rhs), equal_(equal) {}
295 template <typename T>
operator()296 AttrNopEntry operator()(const char* key, T* lhs_value) {
297 if (!result_) return AttrNopEntry();
298 const T* rhs_value = reinterpret_cast<const T*>(
299 reinterpret_cast<const char*>(rhs_) +
300 (reinterpret_cast<const char*>(lhs_value) - reinterpret_cast<const char*>(lhs_)));
301 if (!equal_(*lhs_value, *rhs_value)) {
302 result_ = false;
303 }
304 return AttrNopEntry();
305 }
306
307 private:
308 const Object* lhs_;
309 const Object* rhs_;
310 const SEqualReducer& equal_;
311 };
312
313 class AttrsSHashVisitor {
314 public:
AttrsSHashVisitor(const SHashReducer & hash_reducer)315 explicit AttrsSHashVisitor(const SHashReducer& hash_reducer) : hash_reducer_(hash_reducer) {}
316
317 template <typename T>
operator()318 AttrNopEntry operator()(const char* key, T* value) {
319 hash_reducer_(*value);
320 return AttrNopEntry();
321 }
322
323 private:
324 const SHashReducer& hash_reducer_;
325 };
326
327 // helper entry that does initialization, set default.
328 template <typename T>
329 struct AttrInitEntry {
330 // The attributes
331 using TSelf = AttrInitEntry<T>;
332 // The type key
333 const char* type_key_;
334 // field name
335 const char* key_;
336 // internal value.
337 T* value_;
338 // whether the value is missing.
339 bool value_missing_{true};
340
341 AttrInitEntry() = default;
342
AttrInitEntryAttrInitEntry343 AttrInitEntry(AttrInitEntry&& other) {
344 type_key_ = other.type_key_;
345 key_ = other.key_;
346 value_ = other.value_;
347 value_missing_ = other.value_missing_;
348 // avoid unexpected throw
349 other.value_missing_ = false;
350 }
351
352 // If the value is still missing in destruction time throw an error.
~AttrInitEntryAttrInitEntry353 ~AttrInitEntry() DMLC_THROW_EXCEPTION {
354 if (value_missing_) {
355 std::ostringstream os;
356 os << type_key_ << ": Cannot find required field \'" << key_ << "\' during initialization."
357 << "If the key is defined check that its type matches the declared type.";
358 throw AttrError(os.str());
359 }
360 }
361 // override fields.
362 // This function sets the lower bound of the attribute
set_lower_boundAttrInitEntry363 TSelf& set_lower_bound(const T& begin) {
364 if (this->value_missing_) return *this;
365 const T& val = *value_;
366 if (begin > val) {
367 std::ostringstream os;
368 os << type_key_ << "." << key_ << ": "
369 << "value " << val << " is smaller than the lower bound " << begin;
370 throw AttrError(os.str());
371 }
372 return *this;
373 }
374 // This function sets the upper bound of the attribute
set_upper_boundAttrInitEntry375 TSelf& set_upper_bound(const T& end) {
376 if (this->value_missing_) return *this;
377 const T& val = *value_;
378 if (val > end) {
379 std::ostringstream os;
380 os << type_key_ << "." << key_ << ": "
381 << "value " << val << " is bigger than the upper bound " << end;
382 throw AttrError(os.str());
383 }
384 return *this;
385 }
386 // set default when
set_defaultAttrInitEntry387 TSelf& set_default(const T& value) {
388 if (!value_missing_) return *this;
389 *value_ = value;
390 value_missing_ = false;
391 return *this;
392 }
describeAttrInitEntry393 TSelf& describe(DMLC_ATTRIBUTE_UNUSED const char* str) { return *this; }
394 };
395
396 // Template function to allow smart conversion
397 // from Expr types into the constants.
398 template <typename T>
SetValue(T * ptr,const TVMArgValue & val)399 inline void SetValue(T* ptr, const TVMArgValue& val) {
400 *ptr = val.operator T();
401 }
402
403 template <typename T>
SetIntValue(T * ptr,const TVMArgValue & val)404 inline void SetIntValue(T* ptr, const TVMArgValue& val) {
405 if (val.type_code() == kDLInt) {
406 *ptr = static_cast<T>(val.value().v_int64);
407 } else {
408 IntImm expr = val;
409 *ptr = static_cast<T>(expr->value);
410 }
411 }
412
413 template <>
414 inline void SetValue<std::string>(std::string* ptr, const TVMArgValue& val) {
415 if (String::CanConvertFrom(val)) {
416 *ptr = val.operator std::string();
417 } else {
418 LOG(FATAL) << "Expect str";
419 }
420 }
421
422 template <>
423 inline void SetValue<double>(double* ptr, const TVMArgValue& val) {
424 if (val.type_code() == kDLFloat || val.type_code() == kDLInt) {
425 *ptr = val.operator double();
426 } else {
427 ObjectRef expr = val;
428 CHECK(expr.defined());
429 if (const IntImmNode* op = expr.as<IntImmNode>()) {
430 *ptr = static_cast<double>(op->value);
431 } else if (const FloatImmNode* op = expr.as<FloatImmNode>()) {
432 *ptr = static_cast<double>(op->value);
433 } else {
434 LOG(FATAL) << "Expect float value, but get " << expr->GetTypeKey();
435 }
436 }
437 }
438 template <>
439 inline void SetValue<int>(int* ptr, const TVMArgValue& val) {
440 SetIntValue(ptr, val);
441 }
442 template <>
443 inline void SetValue<int64_t>(int64_t* ptr, const TVMArgValue& val) {
444 SetIntValue(ptr, val);
445 }
446 template <>
447 inline void SetValue<uint64_t>(uint64_t* ptr, const TVMArgValue& val) {
448 SetIntValue(ptr, val);
449 }
450 template <>
451 inline void SetValue<bool>(bool* ptr, const TVMArgValue& val) {
452 SetIntValue(ptr, val);
453 }
454
455 // Visitor for value initialization
456 template <typename FFind>
457 class AttrInitVisitor {
458 public:
459 // Counter of number of matched attributes during visit.
460 // This is used to decide if there is additional unmatched attributes.
461 size_t hit_count_{0};
462 // constructor
AttrInitVisitor(const char * type_key,FFind ffind)463 AttrInitVisitor(const char* type_key, FFind ffind) : type_key_(type_key), ffind_(ffind) {}
464
465 template <typename T>
operator()466 AttrInitEntry<T> operator()(const char* key, T* value) {
467 TVMArgValue val;
468 AttrInitEntry<T> opt;
469 opt.type_key_ = type_key_;
470 opt.key_ = key;
471 opt.value_ = value;
472 if (ffind_(key, &val)) {
473 SetValue(value, val);
474 opt.value_missing_ = false;
475 ++hit_count_;
476 } else {
477 opt.value_missing_ = true;
478 }
479 #if defined(__GNUC__)
480 #pragma GCC diagnostic ignored "-Wpragmas"
481 #pragma GCC diagnostic ignored "-Wpessimizing-move"
482 #endif
483 return std::move(opt);
484 }
485
486 private:
487 // the type key
488 const char* type_key_;
489 FFind ffind_;
490 };
491
492 template <typename FFind>
CreateInitVisitor(const char * type_key,FFind ffind)493 inline AttrInitVisitor<FFind> CreateInitVisitor(const char* type_key, FFind ffind) {
494 return AttrInitVisitor<FFind>(type_key, ffind);
495 }
496
497 /*!
498 * \brief Helper struct to get the type name known to tvm.
499 * \tparam T the type we are interested in.
500 */
501 template <typename T>
502 struct TypeName {
503 static constexpr const char* value = T::ContainerType::_type_key;
504 };
505
506 template <>
507 struct TypeName<int> {
508 static constexpr const char* value = "int";
509 };
510
511 template <>
512 struct TypeName<int64_t> {
513 static constexpr const char* value = "int64";
514 };
515
516 template <>
517 struct TypeName<uint64_t> {
518 static constexpr const char* value = "uint64_t";
519 };
520
521 template <>
522 struct TypeName<DataType> {
523 static constexpr const char* value = "DataType";
524 };
525
526 template <>
527 struct TypeName<std::string> {
528 static constexpr const char* value = "str";
529 };
530
531 template <>
532 struct TypeName<bool> {
533 static constexpr const char* value = "bool";
534 };
535
536 template <>
537 struct TypeName<void*> {
538 static constexpr const char* value = "handle";
539 };
540
541 template <>
542 struct TypeName<double> {
543 static constexpr const char* value = "double";
544 };
545
546 class AttrDocEntry {
547 public:
548 using TSelf = AttrDocEntry;
549
550 explicit AttrDocEntry(ObjectPtr<AttrFieldInfoNode> info) : info_(info) {}
551 TSelf& describe(const char* str) {
552 info_->description = str;
553 return *this;
554 }
555 template <typename T>
556 TSelf& set_default(const T& value) {
557 std::ostringstream os;
558 os << info_->type_info << ", default=" << value;
559 info_->type_info = os.str();
560 return *this;
561 }
562 template <typename T>
563 TSelf& set_lower_bound(DMLC_ATTRIBUTE_UNUSED T begin) {
564 return *this;
565 }
566 template <typename T>
567 TSelf& set_upper_bound(DMLC_ATTRIBUTE_UNUSED T end) {
568 return *this;
569 }
570
571 private:
572 ObjectPtr<AttrFieldInfoNode> info_;
573 };
574
575 class AttrDocVisitor {
576 public:
577 template <typename T>
578 AttrDocEntry operator()(const char* key, T* v) {
579 ObjectPtr<AttrFieldInfoNode> info = make_object<AttrFieldInfoNode>();
580 info->name = key;
581 info->type_info = TypeName<T>::value;
582 fields_.push_back(AttrFieldInfo(info));
583 return AttrDocEntry(info);
584 }
585
586 Array<AttrFieldInfo> fields_;
587 };
588
589 class AttrExistVisitor {
590 public:
591 std::string key_;
592 bool exist_{false};
593
594 template <typename T>
595 AttrNopEntry operator()(const char* key, T* v) {
596 if (exist_) return AttrNopEntry();
597 if (key == key_) exist_ = true;
598 return AttrNopEntry();
599 }
600 };
601
602 template <typename T>
603 struct AttrTriggerNonDefaultEntry {
604 using TSelf = AttrTriggerNonDefaultEntry<T>;
605 // constructor
606 AttrTriggerNonDefaultEntry(AttrVisitor* visitor, const char* key, T* data)
607 : visitor_(visitor), key_(key), data_(data) {}
608
609 ~AttrTriggerNonDefaultEntry() DMLC_THROW_EXCEPTION {
610 if (trigger_) {
611 visitor_->Visit(key_, data_);
612 }
613 }
614 TSelf& describe(DMLC_ATTRIBUTE_UNUSED const char* str) { return *this; }
615 TSelf& set_default(const T& value) {
616 if (tvm::StructuralEqual()(value, *data_)) {
617 trigger_ = false;
618 }
619 return *this;
620 }
621 TSelf& set_lower_bound(DMLC_ATTRIBUTE_UNUSED const T& begin) { return *this; }
622 TSelf& set_upper_bound(DMLC_ATTRIBUTE_UNUSED const T& end) { return *this; }
623
624 private:
625 AttrVisitor* visitor_;
626 const char* key_;
627 T* data_;
628 bool trigger_{true};
629 };
630
631 class AttrNonDefaultVisitor {
632 public:
633 explicit AttrNonDefaultVisitor(AttrVisitor* visitor) : visitor_(visitor) {}
634 template <typename T>
635 AttrTriggerNonDefaultEntry<T> operator()(const char* key, T* value) {
636 return AttrTriggerNonDefaultEntry<T>(visitor_, key, value);
637 }
638
639 private:
640 AttrVisitor* visitor_;
641 };
642 } // namespace detail
643
644 /*!
645 * \brief The base class of the all the
646 * Use "curiously recurring template pattern".
647 *
648 * \tparam DerivedType The final attribute type.
649 */
650 template <typename DerivedType>
651 class AttrsNode : public BaseAttrsNode {
652 public:
653 void VisitAttrs(AttrVisitor* v) {
654 ::tvm::detail::AttrNormalVisitor vis(v);
655 self()->__VisitAttrs__(vis);
656 }
657
658 void VisitNonDefaultAttrs(AttrVisitor* v) {
659 ::tvm::detail::AttrNonDefaultVisitor vis(v);
660 self()->__VisitAttrs__(vis);
661 }
662
663 void InitByPackedArgs(const runtime::TVMArgs& args, bool allow_unknown) final {
664 CHECK_EQ(args.size() % 2, 0);
665 const int kLinearSearchBound = 16;
666 int hit_count = 0;
667 // applies two stratgies to lookup
668 if (args.size() < kLinearSearchBound) {
669 // linear search.
670 auto ffind = [&args](const char* key, runtime::TVMArgValue* val) {
671 for (int i = 0; i < args.size(); i += 2) {
672 CHECK_EQ(args.type_codes[i], kTVMStr);
673 if (!std::strcmp(key, args.values[i].v_str)) {
674 *val = args[i + 1];
675 return true;
676 }
677 }
678 return false;
679 };
680 auto vis = ::tvm::detail::CreateInitVisitor(DerivedType::_type_key, ffind);
681 self()->__VisitAttrs__(vis);
682 hit_count = vis.hit_count_;
683 } else {
684 // construct a map then do lookup.
685 std::unordered_map<std::string, runtime::TVMArgValue> kwargs;
686 for (int i = 0; i < args.size(); i += 2) {
687 CHECK_EQ(args.type_codes[i], kTVMStr);
688 kwargs[args[i].operator std::string()] = args[i + 1];
689 }
690 auto ffind = [&kwargs](const char* key, runtime::TVMArgValue* val) {
691 auto it = kwargs.find(key);
692 if (it != kwargs.end()) {
693 *val = it->second;
694 return true;
695 }
696 return false;
697 };
698 auto vis = ::tvm::detail::CreateInitVisitor(DerivedType::_type_key, ffind);
699 self()->__VisitAttrs__(vis);
700 hit_count = vis.hit_count_;
701 }
702 // error handling, slow path
703 if (hit_count * 2 != args.size() && !allow_unknown) {
704 for (int i = 0; i < args.size(); i += 2) {
705 ::tvm::detail::AttrExistVisitor visitor;
706 visitor.key_ = args[i].operator std::string();
707 self()->__VisitAttrs__(visitor);
708 if (!visitor.exist_) {
709 std::ostringstream os;
710 os << DerivedType::_type_key << ": does not have field \'" << visitor.key_
711 << "\', Possible fields:\n";
712 os << "----------------\n";
713 this->PrintDocString(os);
714 throw AttrError(os.str());
715 }
716 }
717 }
718 }
719
720 bool SEqualReduce(const DerivedType* other, SEqualReducer equal) const {
721 DerivedType* pself = self();
722 ::tvm::detail::AttrsSEqualVisitor visitor(pself, other, equal);
723 self()->__VisitAttrs__(visitor);
724 return visitor.result_;
725 }
726
727 void SHashReduce(SHashReducer hash_reducer) const {
728 ::tvm::detail::AttrsSHashVisitor visitor(hash_reducer);
729 self()->__VisitAttrs__(visitor);
730 }
731
732 Array<AttrFieldInfo> ListFieldInfo() const final {
733 ::tvm::detail::AttrDocVisitor visitor;
734 self()->__VisitAttrs__(visitor);
735 return visitor.fields_;
736 }
737
738 private:
739 DerivedType* self() const {
740 return const_cast<DerivedType*>(static_cast<const DerivedType*>(this));
741 }
742 };
743
744 template <typename... Args>
745 inline void BaseAttrsNode::InitBySeq(Args&&... args) {
746 runtime::PackedFunc pf(
747 [this](const TVMArgs& args, TVMRetValue* rv) { this->InitByPackedArgs(args); });
748 pf(std::forward<Args>(args)...);
749 }
750
751 inline void BaseAttrsNode::PrintDocString(std::ostream& os) const { // NOLINT(*)
752 Array<AttrFieldInfo> entry = this->ListFieldInfo();
753 for (AttrFieldInfo info : entry) {
754 os << info->name << " : " << info->type_info << '\n';
755 if (info->description.length() != 0) {
756 os << " " << info->description << '\n';
757 }
758 }
759 }
760
761 } // namespace tvm
762 #endif // TVM_IR_ATTRS_H_
763