1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19 /*
20 * \file object.cc
21 * \brief Object type management system.
22 */
23 // Acknowledgement: This file originates from incubator-tvm
24
25 #include <dmlc/logging.h>
26 #include <mxnet/runtime/c_runtime_api.h>
27 #include <mxnet/runtime/object.h>
28 #include <mutex>
29 #include <string>
30 #include <vector>
31 #include <utility>
32 #include <unordered_map>
33
34 #include "../c_api/c_api_common.h"
35 #include "./object_internal.h"
36
37 namespace mxnet {
38 namespace runtime {
39
/*!
 * \brief Type information recorded for one entry of the type table.
 *
 * An entry whose allocated_slots is 0 has not been registered yet.
 */
struct TypeInfo {
  /*! \brief The current index. */
  uint32_t index{0};
  /*! \brief Index of the parent in the type hierarchy. */
  uint32_t parent_index{0};
  // NOTE: the indices in [index, index + num_slots) are reserved for this
  // type itself (at `index`) and for its child classes.
  /*! \brief Total number of slots reserved for the type and its children. */
  uint32_t num_slots{0};
  /*!
   * \brief Number of slots of the reserved range already handed out,
   *  counting the type's own slot (set to 1 on registration).
   */
  uint32_t allocated_slots{0};
  /*!
   * \brief Whether children may be allocated outside the reserved range
   *  once it is exhausted.
   */
  bool child_slots_can_overflow{true};
  /*! \brief Name (type key) of the type. */
  std::string name;
  /*! \brief Hash of the name. */
  size_t name_hash{0};
};
59
60 /*!
61 * \brief Type context that manages the type hierachy information.
62 */
63 class TypeContext {
64 public:
65 // NOTE: this is a relatively slow path for child checking
66 // Most types are already checked by the fast-path via reserved slot checking.
DerivedFrom(uint32_t child_tindex,uint32_t parent_tindex)67 bool DerivedFrom(uint32_t child_tindex, uint32_t parent_tindex) {
68 // invariance: child's type index is always bigger than its parent.
69 if (child_tindex < parent_tindex) return false;
70 if (child_tindex == parent_tindex) return true;
71 {
72 std::lock_guard<std::mutex> lock(mutex_);
73 CHECK_LT(child_tindex, type_table_.size());
74 while (child_tindex > parent_tindex) {
75 child_tindex = type_table_[child_tindex].parent_index;
76 }
77 }
78 return child_tindex == parent_tindex;
79 }
80
GetOrAllocRuntimeTypeIndex(const std::string & skey,uint32_t static_tindex,uint32_t parent_tindex,uint32_t num_child_slots,bool child_slots_can_overflow)81 uint32_t GetOrAllocRuntimeTypeIndex(const std::string& skey,
82 uint32_t static_tindex,
83 uint32_t parent_tindex,
84 uint32_t num_child_slots,
85 bool child_slots_can_overflow) {
86 std::lock_guard<std::mutex> lock(mutex_);
87 auto it = type_key2index_.find(skey);
88 if (it != type_key2index_.end()) {
89 return it->second;
90 }
91 // try to allocate from parent's type table.
92 CHECK_LT(parent_tindex, type_table_.size());
93 TypeInfo& pinfo = type_table_[parent_tindex];
94 CHECK_EQ(pinfo.index, parent_tindex);
95
96 // if parent cannot overflow, then this class cannot.
97 if (!pinfo.child_slots_can_overflow) {
98 child_slots_can_overflow = false;
99 }
100
101 // total number of slots include the type itself.
102 uint32_t num_slots = num_child_slots + 1;
103 uint32_t allocated_tindex;
104
105 if (static_tindex != TypeIndex::kDynamic) {
106 // statically assigned type
107 allocated_tindex = static_tindex;
108 CHECK_LT(static_tindex, type_table_.size());
109 CHECK_EQ(type_table_[allocated_tindex].allocated_slots, 0U)
110 << "Conflicting static index " << static_tindex
111 << " between " << type_table_[allocated_tindex].name
112 << " and "
113 << skey;
114 } else if (pinfo.allocated_slots + num_slots < pinfo.num_slots) {
115 // allocate the slot from parent's reserved pool
116 allocated_tindex = parent_tindex + pinfo.allocated_slots;
117 // update parent's state
118 pinfo.allocated_slots += num_slots;
119 } else {
120 CHECK(pinfo.child_slots_can_overflow)
121 << "Reach maximum number of sub-classes for " << pinfo.name;
122 // allocate new entries.
123 allocated_tindex = type_counter_;
124 type_counter_ += num_slots;
125 CHECK_LE(type_table_.size(), allocated_tindex);
126 type_table_.resize(allocated_tindex + 1, TypeInfo());
127 }
128 CHECK_GT(allocated_tindex, parent_tindex);
129 // initialize the slot.
130 type_table_[allocated_tindex].index = allocated_tindex;
131 type_table_[allocated_tindex].parent_index = parent_tindex;
132 type_table_[allocated_tindex].num_slots = num_slots;
133 type_table_[allocated_tindex].allocated_slots = 1;
134 type_table_[allocated_tindex].child_slots_can_overflow =
135 child_slots_can_overflow;
136 type_table_[allocated_tindex].name = skey;
137 type_table_[allocated_tindex].name_hash = std::hash<std::string>()(skey);
138 // update the key2index mapping.
139 type_key2index_[skey] = allocated_tindex;
140 return allocated_tindex;
141 }
142
TypeIndex2Key(uint32_t tindex)143 std::string TypeIndex2Key(uint32_t tindex) {
144 std::lock_guard<std::mutex> lock(mutex_);
145 CHECK(tindex < type_table_.size() &&
146 type_table_[tindex].allocated_slots != 0)
147 << "Unknown type index " << tindex;
148 return type_table_[tindex].name;
149 }
150
TypeIndex2KeyHash(uint32_t tindex)151 size_t TypeIndex2KeyHash(uint32_t tindex) {
152 std::lock_guard<std::mutex> lock(mutex_);
153 CHECK(tindex < type_table_.size() &&
154 type_table_[tindex].allocated_slots != 0)
155 << "Unknown type index " << tindex;
156 return type_table_[tindex].name_hash;
157 }
158
TypeKey2Index(const std::string & skey)159 uint32_t TypeKey2Index(const std::string& skey) {
160 auto it = type_key2index_.find(skey);
161 CHECK(it != type_key2index_.end())
162 << "Cannot find type " << skey;
163 return it->second;
164 }
165
Global()166 static TypeContext* Global() {
167 static TypeContext inst;
168 return &inst;
169 }
170
171 private:
TypeContext()172 TypeContext() {
173 type_table_.resize(TypeIndex::kStaticIndexEnd, TypeInfo());
174 }
175 // mutex to avoid registration from multiple threads.
176 std::mutex mutex_;
177 std::atomic<uint32_t> type_counter_{TypeIndex::kStaticIndexEnd};
178 std::vector<TypeInfo> type_table_;
179 std::unordered_map<std::string, uint32_t> type_key2index_;
180 };
181
GetOrAllocRuntimeTypeIndex(const std::string & key,uint32_t static_tindex,uint32_t parent_tindex,uint32_t num_child_slots,bool child_slots_can_overflow)182 uint32_t Object::GetOrAllocRuntimeTypeIndex(const std::string& key,
183 uint32_t static_tindex,
184 uint32_t parent_tindex,
185 uint32_t num_child_slots,
186 bool child_slots_can_overflow) {
187 return TypeContext::Global()->GetOrAllocRuntimeTypeIndex(
188 key, static_tindex, parent_tindex, num_child_slots, child_slots_can_overflow);
189 }
190
DerivedFrom(uint32_t parent_tindex) const191 bool Object::DerivedFrom(uint32_t parent_tindex) const {
192 return TypeContext::Global()->DerivedFrom(
193 this->type_index_, parent_tindex);
194 }
195
TypeIndex2Key(uint32_t tindex)196 std::string Object::TypeIndex2Key(uint32_t tindex) {
197 return TypeContext::Global()->TypeIndex2Key(tindex);
198 }
199
TypeIndex2KeyHash(uint32_t tindex)200 size_t Object::TypeIndex2KeyHash(uint32_t tindex) {
201 return TypeContext::Global()->TypeIndex2KeyHash(tindex);
202 }
203
TypeKey2Index(const std::string & key)204 uint32_t Object::TypeKey2Index(const std::string& key) {
205 return TypeContext::Global()->TypeKey2Index(key);
206 }
207
208 } // namespace runtime
209 } // namespace mxnet
210
MXNetObjectFree(MXNetObjectHandle obj)211 int MXNetObjectFree(MXNetObjectHandle obj) {
212 API_BEGIN();
213 mxnet::runtime::ObjectInternal::ObjectFree(obj);
214 API_END();
215 }
216