1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19 /*
20 * \file src/runtime/object.cc
21 * \brief Object type management system.
22 */
#include <dmlc/logging.h>
#include <tvm/runtime/object.h>
#include <tvm/runtime/registry.h>

#include <atomic>
#include <functional>
#include <iostream>
#include <mutex>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include "object_internal.h"
#include "runtime_base.h"
36
37 namespace tvm {
38 namespace runtime {
39
/*! \brief Bookkeeping record for one entry of the type table. */
struct TypeInfo {
  /*! \brief The type index of this entry (equal to its position in the type table once registered). */
  uint32_t index{0};
  /*! \brief Index of the parent in the type hierarchy. */
  uint32_t parent_index{0};
  // NOTE: the indices in [index, index + num_slots) are reserved for this type:
  // `index` itself holds the type, the remaining slots are for its child classes.
  /*! \brief Total number of slots reserved for the type and its children. */
  uint32_t num_slots{0};
  /*!
   * \brief Number of slots in [index, index + num_slots) already handed out.
   *  This counts the type's own slot, so it is 1 right after registration.
   */
  uint32_t allocated_slots{0};
  /*!
   * \brief Whether child types may be allocated outside the reserved range
   *  once the reserved slots are exhausted.
   */
  bool child_slots_can_overflow{true};
  /*! \brief Name (type key) of the type. */
  std::string name;
  /*! \brief Hash of the name. */
  size_t name_hash{0};
};
59
60 /*!
61 * \brief Type context that manages the type hierachy information.
62 */
63 class TypeContext {
64 public:
65 // NOTE: this is a relatively slow path for child checking
66 // Most types are already checked by the fast-path via reserved slot checking.
DerivedFrom(uint32_t child_tindex,uint32_t parent_tindex)67 bool DerivedFrom(uint32_t child_tindex, uint32_t parent_tindex) {
68 // invariance: child's type index is always bigger than its parent.
69 if (child_tindex < parent_tindex) return false;
70 if (child_tindex == parent_tindex) return true;
71 {
72 std::lock_guard<std::mutex> lock(mutex_);
73 CHECK_LT(child_tindex, type_table_.size());
74 while (child_tindex > parent_tindex) {
75 child_tindex = type_table_[child_tindex].parent_index;
76 }
77 }
78 return child_tindex == parent_tindex;
79 }
80
GetOrAllocRuntimeTypeIndex(const std::string & skey,uint32_t static_tindex,uint32_t parent_tindex,uint32_t num_child_slots,bool child_slots_can_overflow)81 uint32_t GetOrAllocRuntimeTypeIndex(const std::string& skey, uint32_t static_tindex,
82 uint32_t parent_tindex, uint32_t num_child_slots,
83 bool child_slots_can_overflow) {
84 std::lock_guard<std::mutex> lock(mutex_);
85 auto it = type_key2index_.find(skey);
86 if (it != type_key2index_.end()) {
87 return it->second;
88 }
89 // try to allocate from parent's type table.
90 CHECK_LT(parent_tindex, type_table_.size())
91 << " skey= " << skey << "static_index=" << static_tindex;
92 TypeInfo& pinfo = type_table_[parent_tindex];
93 CHECK_EQ(pinfo.index, parent_tindex);
94
95 // if parent cannot overflow, then this class cannot.
96 if (!pinfo.child_slots_can_overflow) {
97 child_slots_can_overflow = false;
98 }
99
100 // total number of slots include the type itself.
101 uint32_t num_slots = num_child_slots + 1;
102 uint32_t allocated_tindex;
103
104 if (static_tindex != TypeIndex::kDynamic) {
105 // statically assigned type
106 allocated_tindex = static_tindex;
107 CHECK_LT(static_tindex, type_table_.size());
108 CHECK_EQ(type_table_[allocated_tindex].allocated_slots, 0U)
109 << "Conflicting static index " << static_tindex << " between "
110 << type_table_[allocated_tindex].name << " and " << skey;
111 } else if (pinfo.allocated_slots + num_slots <= pinfo.num_slots) {
112 // allocate the slot from parent's reserved pool
113 allocated_tindex = parent_tindex + pinfo.allocated_slots;
114 // update parent's state
115 pinfo.allocated_slots += num_slots;
116 } else {
117 CHECK(pinfo.child_slots_can_overflow)
118 << "Reach maximum number of sub-classes for " << pinfo.name;
119 // allocate new entries.
120 allocated_tindex = type_counter_;
121 type_counter_ += num_slots;
122 CHECK_LE(type_table_.size(), type_counter_);
123 type_table_.resize(type_counter_, TypeInfo());
124 }
125 CHECK_GT(allocated_tindex, parent_tindex);
126 // initialize the slot.
127 type_table_[allocated_tindex].index = allocated_tindex;
128 type_table_[allocated_tindex].parent_index = parent_tindex;
129 type_table_[allocated_tindex].num_slots = num_slots;
130 type_table_[allocated_tindex].allocated_slots = 1;
131 type_table_[allocated_tindex].child_slots_can_overflow = child_slots_can_overflow;
132 type_table_[allocated_tindex].name = skey;
133 type_table_[allocated_tindex].name_hash = std::hash<std::string>()(skey);
134 // update the key2index mapping.
135 type_key2index_[skey] = allocated_tindex;
136 return allocated_tindex;
137 }
138
TypeIndex2Key(uint32_t tindex)139 std::string TypeIndex2Key(uint32_t tindex) {
140 std::lock_guard<std::mutex> lock(mutex_);
141 CHECK(tindex < type_table_.size() && type_table_[tindex].allocated_slots != 0)
142 << "Unknown type index " << tindex;
143 return type_table_[tindex].name;
144 }
145
TypeIndex2KeyHash(uint32_t tindex)146 size_t TypeIndex2KeyHash(uint32_t tindex) {
147 std::lock_guard<std::mutex> lock(mutex_);
148 CHECK(tindex < type_table_.size() && type_table_[tindex].allocated_slots != 0)
149 << "Unknown type index " << tindex;
150 return type_table_[tindex].name_hash;
151 }
152
TypeKey2Index(const std::string & skey)153 uint32_t TypeKey2Index(const std::string& skey) {
154 auto it = type_key2index_.find(skey);
155 CHECK(it != type_key2index_.end())
156 << "Cannot find type " << skey
157 << ". Did you forget to register the node by TVM_REGISTER_NODE_TYPE ?";
158 return it->second;
159 }
160
Dump(int min_children_count)161 void Dump(int min_children_count) {
162 std::vector<int> num_children(type_table_.size(), 0);
163 // reverse accumulation so we can get total counts in a bottom-up manner.
164 for (auto it = type_table_.rbegin(); it != type_table_.rend(); ++it) {
165 if (it->index != 0) {
166 num_children[it->parent_index] += num_children[it->index] + 1;
167 }
168 }
169
170 for (const auto& info : type_table_) {
171 if (info.index != 0 && num_children[info.index] >= min_children_count) {
172 std::cerr << '[' << info.index << "] " << info.name
173 << "\tparent=" << type_table_[info.parent_index].name
174 << "\tnum_child_slots=" << info.num_slots - 1
175 << "\tnum_children=" << num_children[info.index] << std::endl;
176 }
177 }
178 }
179
Global()180 static TypeContext* Global() {
181 static TypeContext inst;
182 return &inst;
183 }
184
185 private:
TypeContext()186 TypeContext() {
187 type_table_.resize(TypeIndex::kStaticIndexEnd, TypeInfo());
188 type_table_[0].name = "runtime.Object";
189 }
190 // mutex to avoid registration from multiple threads.
191 std::mutex mutex_;
192 std::atomic<uint32_t> type_counter_{TypeIndex::kStaticIndexEnd};
193 std::vector<TypeInfo> type_table_;
194 std::unordered_map<std::string, uint32_t> type_key2index_;
195 };
196
GetOrAllocRuntimeTypeIndex(const std::string & key,uint32_t static_tindex,uint32_t parent_tindex,uint32_t num_child_slots,bool child_slots_can_overflow)197 uint32_t Object::GetOrAllocRuntimeTypeIndex(const std::string& key, uint32_t static_tindex,
198 uint32_t parent_tindex, uint32_t num_child_slots,
199 bool child_slots_can_overflow) {
200 return TypeContext::Global()->GetOrAllocRuntimeTypeIndex(
201 key, static_tindex, parent_tindex, num_child_slots, child_slots_can_overflow);
202 }
203
DerivedFrom(uint32_t parent_tindex) const204 bool Object::DerivedFrom(uint32_t parent_tindex) const {
205 return TypeContext::Global()->DerivedFrom(this->type_index_, parent_tindex);
206 }
207
TypeIndex2Key(uint32_t tindex)208 std::string Object::TypeIndex2Key(uint32_t tindex) {
209 return TypeContext::Global()->TypeIndex2Key(tindex);
210 }
211
TypeIndex2KeyHash(uint32_t tindex)212 size_t Object::TypeIndex2KeyHash(uint32_t tindex) {
213 return TypeContext::Global()->TypeIndex2KeyHash(tindex);
214 }
215
TypeKey2Index(const std::string & key)216 uint32_t Object::TypeKey2Index(const std::string& key) {
217 return TypeContext::Global()->TypeKey2Index(key);
218 }
219
// FFI: hash an object with ObjectPtrHash and return the value as int64_t.
TVM_REGISTER_GLOBAL("runtime.ObjectPtrHash").set_body_typed([](ObjectRef obj) {
  const auto ptr_hash = ObjectPtrHash()(obj);
  return static_cast<int64_t>(ptr_hash);
});
223
__anon2701dcb50202(int min_child_count) 224 TVM_REGISTER_GLOBAL("runtime.DumpTypeTable").set_body_typed([](int min_child_count) {
225 TypeContext::Global()->Dump(min_child_count);
226 });
227 } // namespace runtime
228 } // namespace tvm
229
TVMObjectGetTypeIndex(TVMObjectHandle obj,unsigned * out_tindex)230 int TVMObjectGetTypeIndex(TVMObjectHandle obj, unsigned* out_tindex) {
231 API_BEGIN();
232 CHECK(obj != nullptr);
233 out_tindex[0] = static_cast<tvm::runtime::Object*>(obj)->type_index();
234 API_END();
235 }
236
TVMObjectRetain(TVMObjectHandle obj)237 int TVMObjectRetain(TVMObjectHandle obj) {
238 API_BEGIN();
239 tvm::runtime::ObjectInternal::ObjectRetain(obj);
240 API_END();
241 }
242
TVMObjectFree(TVMObjectHandle obj)243 int TVMObjectFree(TVMObjectHandle obj) {
244 API_BEGIN();
245 tvm::runtime::ObjectInternal::ObjectFree(obj);
246 API_END();
247 }
248
TVMObjectDerivedFrom(uint32_t child_type_index,uint32_t parent_type_index,int * is_derived)249 int TVMObjectDerivedFrom(uint32_t child_type_index, uint32_t parent_type_index, int* is_derived) {
250 API_BEGIN();
251 *is_derived =
252 tvm::runtime::TypeContext::Global()->DerivedFrom(child_type_index, parent_type_index);
253 API_END();
254 }
255
TVMObjectTypeKey2Index(const char * type_key,unsigned * out_tindex)256 int TVMObjectTypeKey2Index(const char* type_key, unsigned* out_tindex) {
257 API_BEGIN();
258 out_tindex[0] = tvm::runtime::ObjectInternal::ObjectTypeKey2Index(type_key);
259 API_END();
260 }
261