1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements.  See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership.  The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License.  You may obtain a copy of the License at
9  *
10  *   http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied.  See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 /*
20  * \file src/runtime/object.cc
21  * \brief Object type management system.
22  */
23 #include <dmlc/logging.h>
24 #include <tvm/runtime/object.h>
25 #include <tvm/runtime/registry.h>
26 
27 #include <iostream>
28 #include <mutex>
29 #include <string>
30 #include <unordered_map>
31 #include <utility>
32 #include <vector>
33 
34 #include "object_internal.h"
35 #include "runtime_base.h"
36 
37 namespace tvm {
38 namespace runtime {
39 
/*! \brief Type information */
struct TypeInfo {
  /*! \brief The current index (0 for table slots that were never registered). */
  uint32_t index{0};
  /*! \brief Index of the parent in the type hierarchy */
  uint32_t parent_index{0};
  // NOTE: the indices in [index, index + num_slots) are
  // reserved for the child-classes of this type.
  /*! \brief Total number of slots reserved for the type and its children. */
  uint32_t num_slots{0};
  /*!
   * \brief Number of slots in [index, index + num_slots) already handed out,
   *        including the type's own slot. A value of 0 marks an unregistered entry.
   */
  uint32_t allocated_slots{0};
  /*!
   * \brief Whether children may be allocated outside the reserved range
   *        once it is exhausted.
   */
  bool child_slots_can_overflow{true};
  /*! \brief name of the type. */
  std::string name;
  /*! \brief hash of the name */
  size_t name_hash{0};
};
59 
60 /*!
61  * \brief Type context that manages the type hierachy information.
62  */
63 class TypeContext {
64  public:
65   // NOTE: this is a relatively slow path for child checking
66   // Most types are already checked by the fast-path via reserved slot checking.
DerivedFrom(uint32_t child_tindex,uint32_t parent_tindex)67   bool DerivedFrom(uint32_t child_tindex, uint32_t parent_tindex) {
68     // invariance: child's type index is always bigger than its parent.
69     if (child_tindex < parent_tindex) return false;
70     if (child_tindex == parent_tindex) return true;
71     {
72       std::lock_guard<std::mutex> lock(mutex_);
73       CHECK_LT(child_tindex, type_table_.size());
74       while (child_tindex > parent_tindex) {
75         child_tindex = type_table_[child_tindex].parent_index;
76       }
77     }
78     return child_tindex == parent_tindex;
79   }
80 
GetOrAllocRuntimeTypeIndex(const std::string & skey,uint32_t static_tindex,uint32_t parent_tindex,uint32_t num_child_slots,bool child_slots_can_overflow)81   uint32_t GetOrAllocRuntimeTypeIndex(const std::string& skey, uint32_t static_tindex,
82                                       uint32_t parent_tindex, uint32_t num_child_slots,
83                                       bool child_slots_can_overflow) {
84     std::lock_guard<std::mutex> lock(mutex_);
85     auto it = type_key2index_.find(skey);
86     if (it != type_key2index_.end()) {
87       return it->second;
88     }
89     // try to allocate from parent's type table.
90     CHECK_LT(parent_tindex, type_table_.size())
91         << " skey= " << skey << "static_index=" << static_tindex;
92     TypeInfo& pinfo = type_table_[parent_tindex];
93     CHECK_EQ(pinfo.index, parent_tindex);
94 
95     // if parent cannot overflow, then this class cannot.
96     if (!pinfo.child_slots_can_overflow) {
97       child_slots_can_overflow = false;
98     }
99 
100     // total number of slots include the type itself.
101     uint32_t num_slots = num_child_slots + 1;
102     uint32_t allocated_tindex;
103 
104     if (static_tindex != TypeIndex::kDynamic) {
105       // statically assigned type
106       allocated_tindex = static_tindex;
107       CHECK_LT(static_tindex, type_table_.size());
108       CHECK_EQ(type_table_[allocated_tindex].allocated_slots, 0U)
109           << "Conflicting static index " << static_tindex << " between "
110           << type_table_[allocated_tindex].name << " and " << skey;
111     } else if (pinfo.allocated_slots + num_slots <= pinfo.num_slots) {
112       // allocate the slot from parent's reserved pool
113       allocated_tindex = parent_tindex + pinfo.allocated_slots;
114       // update parent's state
115       pinfo.allocated_slots += num_slots;
116     } else {
117       CHECK(pinfo.child_slots_can_overflow)
118           << "Reach maximum number of sub-classes for " << pinfo.name;
119       // allocate new entries.
120       allocated_tindex = type_counter_;
121       type_counter_ += num_slots;
122       CHECK_LE(type_table_.size(), type_counter_);
123       type_table_.resize(type_counter_, TypeInfo());
124     }
125     CHECK_GT(allocated_tindex, parent_tindex);
126     // initialize the slot.
127     type_table_[allocated_tindex].index = allocated_tindex;
128     type_table_[allocated_tindex].parent_index = parent_tindex;
129     type_table_[allocated_tindex].num_slots = num_slots;
130     type_table_[allocated_tindex].allocated_slots = 1;
131     type_table_[allocated_tindex].child_slots_can_overflow = child_slots_can_overflow;
132     type_table_[allocated_tindex].name = skey;
133     type_table_[allocated_tindex].name_hash = std::hash<std::string>()(skey);
134     // update the key2index mapping.
135     type_key2index_[skey] = allocated_tindex;
136     return allocated_tindex;
137   }
138 
TypeIndex2Key(uint32_t tindex)139   std::string TypeIndex2Key(uint32_t tindex) {
140     std::lock_guard<std::mutex> lock(mutex_);
141     CHECK(tindex < type_table_.size() && type_table_[tindex].allocated_slots != 0)
142         << "Unknown type index " << tindex;
143     return type_table_[tindex].name;
144   }
145 
TypeIndex2KeyHash(uint32_t tindex)146   size_t TypeIndex2KeyHash(uint32_t tindex) {
147     std::lock_guard<std::mutex> lock(mutex_);
148     CHECK(tindex < type_table_.size() && type_table_[tindex].allocated_slots != 0)
149         << "Unknown type index " << tindex;
150     return type_table_[tindex].name_hash;
151   }
152 
TypeKey2Index(const std::string & skey)153   uint32_t TypeKey2Index(const std::string& skey) {
154     auto it = type_key2index_.find(skey);
155     CHECK(it != type_key2index_.end())
156         << "Cannot find type " << skey
157         << ". Did you forget to register the node by TVM_REGISTER_NODE_TYPE ?";
158     return it->second;
159   }
160 
Dump(int min_children_count)161   void Dump(int min_children_count) {
162     std::vector<int> num_children(type_table_.size(), 0);
163     // reverse accumulation so we can get total counts in a bottom-up manner.
164     for (auto it = type_table_.rbegin(); it != type_table_.rend(); ++it) {
165       if (it->index != 0) {
166         num_children[it->parent_index] += num_children[it->index] + 1;
167       }
168     }
169 
170     for (const auto& info : type_table_) {
171       if (info.index != 0 && num_children[info.index] >= min_children_count) {
172         std::cerr << '[' << info.index << "] " << info.name
173                   << "\tparent=" << type_table_[info.parent_index].name
174                   << "\tnum_child_slots=" << info.num_slots - 1
175                   << "\tnum_children=" << num_children[info.index] << std::endl;
176       }
177     }
178   }
179 
Global()180   static TypeContext* Global() {
181     static TypeContext inst;
182     return &inst;
183   }
184 
185  private:
TypeContext()186   TypeContext() {
187     type_table_.resize(TypeIndex::kStaticIndexEnd, TypeInfo());
188     type_table_[0].name = "runtime.Object";
189   }
190   // mutex to avoid registration from multiple threads.
191   std::mutex mutex_;
192   std::atomic<uint32_t> type_counter_{TypeIndex::kStaticIndexEnd};
193   std::vector<TypeInfo> type_table_;
194   std::unordered_map<std::string, uint32_t> type_key2index_;
195 };
196 
GetOrAllocRuntimeTypeIndex(const std::string & key,uint32_t static_tindex,uint32_t parent_tindex,uint32_t num_child_slots,bool child_slots_can_overflow)197 uint32_t Object::GetOrAllocRuntimeTypeIndex(const std::string& key, uint32_t static_tindex,
198                                             uint32_t parent_tindex, uint32_t num_child_slots,
199                                             bool child_slots_can_overflow) {
200   return TypeContext::Global()->GetOrAllocRuntimeTypeIndex(
201       key, static_tindex, parent_tindex, num_child_slots, child_slots_can_overflow);
202 }
203 
DerivedFrom(uint32_t parent_tindex) const204 bool Object::DerivedFrom(uint32_t parent_tindex) const {
205   return TypeContext::Global()->DerivedFrom(this->type_index_, parent_tindex);
206 }
207 
TypeIndex2Key(uint32_t tindex)208 std::string Object::TypeIndex2Key(uint32_t tindex) {
209   return TypeContext::Global()->TypeIndex2Key(tindex);
210 }
211 
TypeIndex2KeyHash(uint32_t tindex)212 size_t Object::TypeIndex2KeyHash(uint32_t tindex) {
213   return TypeContext::Global()->TypeIndex2KeyHash(tindex);
214 }
215 
TypeKey2Index(const std::string & key)216 uint32_t Object::TypeKey2Index(const std::string& key) {
217   return TypeContext::Global()->TypeKey2Index(key);
218 }
219 
// Expose ObjectPtrHash to the FFI, returning the hash as a signed 64-bit integer.
TVM_REGISTER_GLOBAL("runtime.ObjectPtrHash").set_body_typed([](ObjectRef obj) -> int64_t {
  ObjectPtrHash hasher;
  const size_t hash_value = hasher(obj);
  return static_cast<int64_t>(hash_value);
});
223 
__anon2701dcb50202(int min_child_count) 224 TVM_REGISTER_GLOBAL("runtime.DumpTypeTable").set_body_typed([](int min_child_count) {
225   TypeContext::Global()->Dump(min_child_count);
226 });
227 }  // namespace runtime
228 }  // namespace tvm
229 
TVMObjectGetTypeIndex(TVMObjectHandle obj,unsigned * out_tindex)230 int TVMObjectGetTypeIndex(TVMObjectHandle obj, unsigned* out_tindex) {
231   API_BEGIN();
232   CHECK(obj != nullptr);
233   out_tindex[0] = static_cast<tvm::runtime::Object*>(obj)->type_index();
234   API_END();
235 }
236 
TVMObjectRetain(TVMObjectHandle obj)237 int TVMObjectRetain(TVMObjectHandle obj) {
238   API_BEGIN();
239   tvm::runtime::ObjectInternal::ObjectRetain(obj);
240   API_END();
241 }
242 
TVMObjectFree(TVMObjectHandle obj)243 int TVMObjectFree(TVMObjectHandle obj) {
244   API_BEGIN();
245   tvm::runtime::ObjectInternal::ObjectFree(obj);
246   API_END();
247 }
248 
TVMObjectDerivedFrom(uint32_t child_type_index,uint32_t parent_type_index,int * is_derived)249 int TVMObjectDerivedFrom(uint32_t child_type_index, uint32_t parent_type_index, int* is_derived) {
250   API_BEGIN();
251   *is_derived =
252       tvm::runtime::TypeContext::Global()->DerivedFrom(child_type_index, parent_type_index);
253   API_END();
254 }
255 
TVMObjectTypeKey2Index(const char * type_key,unsigned * out_tindex)256 int TVMObjectTypeKey2Index(const char* type_key, unsigned* out_tindex) {
257   API_BEGIN();
258   out_tindex[0] = tvm::runtime::ObjectInternal::ObjectTypeKey2Index(type_key);
259   API_END();
260 }
261