1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements.  See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership.  The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License.  You may obtain a copy of the License at
9  *
10  *   http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied.  See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 /*
20  * \file object.cc
21  * \brief Object type management system.
22  */
23 // Acknowledgement: This file originates from incubator-tvm
24 
#include <dmlc/logging.h>
#include <mxnet/runtime/c_runtime_api.h>
#include <mxnet/runtime/object.h>
#include <atomic>
#include <mutex>
#include <string>
#include <vector>
#include <utility>
#include <unordered_map>
33 
34 #include "../c_api/c_api_common.h"
35 #include "./object_internal.h"
36 
37 namespace mxnet {
38 namespace runtime {
39 
/*!
 * \brief Bookkeeping record for one runtime type index.
 *
 * The slots in [index, index + num_slots) belong to this type: `index` itself
 * identifies the type, and the remaining slots are reserved for its descendants.
 */
struct TypeInfo {
  /*! \brief The type index this record describes. */
  uint32_t index{0};
  /*! \brief Index of the parent in the type hierarchy. */
  uint32_t parent_index{0};
  /*! \brief Total number of slots reserved for the type itself and its children. */
  uint32_t num_slots{0};
  /*!
   * \brief Number of slots in [index, index + num_slots) already handed out,
   *        counting the type's own slot.
   */
  uint32_t allocated_slots{0};
  /*! \brief Whether children may be allocated outside the reserved slots. */
  bool child_slots_can_overflow{true};
  /*! \brief Name (type key) of the type. */
  std::string name;
  /*! \brief std::hash of `name`. */
  size_t name_hash{0};
};
59 
60 /*!
61  * \brief Type context that manages the type hierachy information.
62  */
63 class TypeContext {
64  public:
65   // NOTE: this is a relatively slow path for child checking
66   // Most types are already checked by the fast-path via reserved slot checking.
DerivedFrom(uint32_t child_tindex,uint32_t parent_tindex)67   bool DerivedFrom(uint32_t child_tindex, uint32_t parent_tindex) {
68     // invariance: child's type index is always bigger than its parent.
69     if (child_tindex < parent_tindex) return false;
70     if (child_tindex == parent_tindex) return true;
71     {
72       std::lock_guard<std::mutex> lock(mutex_);
73       CHECK_LT(child_tindex, type_table_.size());
74       while (child_tindex > parent_tindex) {
75         child_tindex = type_table_[child_tindex].parent_index;
76       }
77     }
78     return child_tindex == parent_tindex;
79   }
80 
GetOrAllocRuntimeTypeIndex(const std::string & skey,uint32_t static_tindex,uint32_t parent_tindex,uint32_t num_child_slots,bool child_slots_can_overflow)81   uint32_t GetOrAllocRuntimeTypeIndex(const std::string& skey,
82                                       uint32_t static_tindex,
83                                       uint32_t parent_tindex,
84                                       uint32_t num_child_slots,
85                                       bool child_slots_can_overflow) {
86     std::lock_guard<std::mutex> lock(mutex_);
87     auto it = type_key2index_.find(skey);
88     if (it != type_key2index_.end()) {
89       return it->second;
90     }
91     // try to allocate from parent's type table.
92     CHECK_LT(parent_tindex, type_table_.size());
93     TypeInfo& pinfo = type_table_[parent_tindex];
94     CHECK_EQ(pinfo.index, parent_tindex);
95 
96     // if parent cannot overflow, then this class cannot.
97     if (!pinfo.child_slots_can_overflow) {
98       child_slots_can_overflow = false;
99     }
100 
101     // total number of slots include the type itself.
102     uint32_t num_slots = num_child_slots + 1;
103     uint32_t allocated_tindex;
104 
105     if (static_tindex != TypeIndex::kDynamic) {
106       // statically assigned type
107       allocated_tindex = static_tindex;
108       CHECK_LT(static_tindex, type_table_.size());
109       CHECK_EQ(type_table_[allocated_tindex].allocated_slots, 0U)
110           << "Conflicting static index " << static_tindex
111           << " between " << type_table_[allocated_tindex].name
112           << " and "
113           << skey;
114     } else if (pinfo.allocated_slots + num_slots < pinfo.num_slots) {
115       // allocate the slot from parent's reserved pool
116       allocated_tindex = parent_tindex + pinfo.allocated_slots;
117       // update parent's state
118       pinfo.allocated_slots += num_slots;
119     } else {
120       CHECK(pinfo.child_slots_can_overflow)
121           << "Reach maximum number of sub-classes for " << pinfo.name;
122       // allocate new entries.
123       allocated_tindex = type_counter_;
124       type_counter_ += num_slots;
125       CHECK_LE(type_table_.size(), allocated_tindex);
126       type_table_.resize(allocated_tindex + 1, TypeInfo());
127     }
128     CHECK_GT(allocated_tindex, parent_tindex);
129     // initialize the slot.
130     type_table_[allocated_tindex].index = allocated_tindex;
131     type_table_[allocated_tindex].parent_index = parent_tindex;
132     type_table_[allocated_tindex].num_slots = num_slots;
133     type_table_[allocated_tindex].allocated_slots = 1;
134     type_table_[allocated_tindex].child_slots_can_overflow =
135         child_slots_can_overflow;
136     type_table_[allocated_tindex].name = skey;
137     type_table_[allocated_tindex].name_hash = std::hash<std::string>()(skey);
138     // update the key2index mapping.
139     type_key2index_[skey] = allocated_tindex;
140     return allocated_tindex;
141   }
142 
TypeIndex2Key(uint32_t tindex)143   std::string TypeIndex2Key(uint32_t tindex) {
144     std::lock_guard<std::mutex> lock(mutex_);
145     CHECK(tindex < type_table_.size() &&
146           type_table_[tindex].allocated_slots != 0)
147         << "Unknown type index " << tindex;
148     return type_table_[tindex].name;
149   }
150 
TypeIndex2KeyHash(uint32_t tindex)151   size_t TypeIndex2KeyHash(uint32_t tindex) {
152     std::lock_guard<std::mutex> lock(mutex_);
153     CHECK(tindex < type_table_.size() &&
154           type_table_[tindex].allocated_slots != 0)
155         << "Unknown type index " << tindex;
156     return type_table_[tindex].name_hash;
157   }
158 
TypeKey2Index(const std::string & skey)159   uint32_t TypeKey2Index(const std::string& skey) {
160     auto it = type_key2index_.find(skey);
161     CHECK(it != type_key2index_.end())
162         << "Cannot find type " << skey;
163     return it->second;
164   }
165 
Global()166   static TypeContext* Global() {
167     static TypeContext inst;
168     return &inst;
169   }
170 
171  private:
TypeContext()172   TypeContext() {
173     type_table_.resize(TypeIndex::kStaticIndexEnd, TypeInfo());
174   }
175   // mutex to avoid registration from multiple threads.
176   std::mutex mutex_;
177   std::atomic<uint32_t> type_counter_{TypeIndex::kStaticIndexEnd};
178   std::vector<TypeInfo> type_table_;
179   std::unordered_map<std::string, uint32_t> type_key2index_;
180 };
181 
GetOrAllocRuntimeTypeIndex(const std::string & key,uint32_t static_tindex,uint32_t parent_tindex,uint32_t num_child_slots,bool child_slots_can_overflow)182 uint32_t Object::GetOrAllocRuntimeTypeIndex(const std::string& key,
183                                             uint32_t static_tindex,
184                                             uint32_t parent_tindex,
185                                             uint32_t num_child_slots,
186                                             bool child_slots_can_overflow) {
187   return TypeContext::Global()->GetOrAllocRuntimeTypeIndex(
188       key, static_tindex, parent_tindex, num_child_slots, child_slots_can_overflow);
189 }
190 
DerivedFrom(uint32_t parent_tindex) const191 bool Object::DerivedFrom(uint32_t parent_tindex) const {
192   return TypeContext::Global()->DerivedFrom(
193       this->type_index_, parent_tindex);
194 }
195 
TypeIndex2Key(uint32_t tindex)196 std::string Object::TypeIndex2Key(uint32_t tindex) {
197   return TypeContext::Global()->TypeIndex2Key(tindex);
198 }
199 
TypeIndex2KeyHash(uint32_t tindex)200 size_t Object::TypeIndex2KeyHash(uint32_t tindex) {
201   return TypeContext::Global()->TypeIndex2KeyHash(tindex);
202 }
203 
TypeKey2Index(const std::string & key)204 uint32_t Object::TypeKey2Index(const std::string& key) {
205   return TypeContext::Global()->TypeKey2Index(key);
206 }
207 
208 }  // namespace runtime
209 }  // namespace mxnet
210 
MXNetObjectFree(MXNetObjectHandle obj)211 int MXNetObjectFree(MXNetObjectHandle obj) {
212   API_BEGIN();
213   mxnet::runtime::ObjectInternal::ObjectFree(obj);
214   API_END();
215 }
216