1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20 /*!
21 * \file src/tvm/relay/op.cc
 * \brief Relay operator registry: operator lookup, per-operator attribute tables, and frontend registration APIs.
23 */
#include <tvm/relay/op.h>
#include <tvm/relay/type.h>
#include <tvm/runtime/module.h>
#include <tvm/runtime/packed_func.h>

#include <atomic>
#include <memory>
#include <mutex>
#include <string>
#include <unordered_map>
#include <vector>
31
namespace dmlc {
// enable registry
// Instantiates the dmlc registry for Relay operators; presumably this supplies
// the definition behind ::dmlc::Registry<OpRegistry>::Get(), which
// OpRegistry::Registry() calls. The macro specializes a dmlc template, so it
// has to remain inside namespace dmlc.
DMLC_REGISTRY_ENABLE(::tvm::relay::OpRegistry);
}  // namespace dmlc
36
37 namespace tvm {
38 namespace relay {
39
Registry()40 ::dmlc::Registry<OpRegistry>* OpRegistry::Registry() {
41 return ::dmlc::Registry<OpRegistry>::Get();
42 }
43
44 // single manager of operator information.
45 struct OpManager {
46 // mutex to avoid registration from multiple threads.
47 std::mutex mutex;
48 // global operator counter
49 std::atomic<int> op_counter{0};
50 // storage of additional attribute table.
51 std::unordered_map<std::string, std::unique_ptr<GenericOpMap>> attr;
52 // frontend functions
53 std::vector<PackedFunc*> frontend_funcs;
54 // get singleton of the op manager
Globaltvm::relay::OpManager55 static OpManager* Global() {
56 static OpManager* inst = new OpManager();
57 return inst;
58 }
59 };
60
61 // find operator by name
Get(const std::string & name)62 const Op& Op::Get(const std::string& name) {
63 const OpRegistry* reg = dmlc::Registry<OpRegistry>::Find(name);
64 CHECK(reg != nullptr) << "Operator " << name << " is not registered";
65 return reg->op();
66 }
67
OpRegistry()68 OpRegistry::OpRegistry() {
69 OpManager* mgr = OpManager::Global();
70 NodePtr<OpNode> n = make_node<OpNode>();
71 n->index_ = mgr->op_counter++;
72 op_ = Op(n);
73 }
74
75 // Get attribute map by key
GetGenericAttr(const std::string & key)76 const GenericOpMap& Op::GetGenericAttr(const std::string& key) {
77 OpManager* mgr = OpManager::Global();
78 std::lock_guard<std::mutex> lock(mgr->mutex);
79 auto it = mgr->attr.find(key);
80 if (it == mgr->attr.end()) {
81 LOG(FATAL) << "Operator attribute \'" << key << "\' is not registered";
82 }
83 return *it->second.get();
84 }
85
86 // Check if a key is present in the registry.
HasGenericAttr(const std::string & key)87 const bool Op::HasGenericAttr(const std::string& key) {
88 OpManager* mgr = OpManager::Global();
89 std::lock_guard<std::mutex> lock(mgr->mutex);
90 auto it = mgr->attr.find(key);
91 if (it == mgr->attr.end()) {
92 return false;
93 }
94 return true;
95 }
96
97 // Resets attr of the OpMap.
reset_attr(const std::string & key)98 void OpRegistry::reset_attr(const std::string& key) {
99 OpManager* mgr = OpManager::Global();
100 std::lock_guard<std::mutex> lock(mgr->mutex);
101 std::unique_ptr<GenericOpMap>& op_map = mgr->attr[key];
102 if (op_map == nullptr) {
103 return;
104 }
105 uint32_t index = op_->index_;
106 if (op_map->data_.size() > index) {
107 op_map->data_[index] = std::make_pair(TVMRetValue(), 0);
108 }
109 }
110
UpdateAttr(const std::string & key,TVMRetValue value,int plevel)111 void OpRegistry::UpdateAttr(const std::string& key,
112 TVMRetValue value,
113 int plevel) {
114 OpManager* mgr = OpManager::Global();
115 std::lock_guard<std::mutex> lock(mgr->mutex);
116 std::unique_ptr<GenericOpMap>& op_map = mgr->attr[key];
117 if (op_map == nullptr) {
118 op_map.reset(new GenericOpMap());
119 op_map->attr_name_ = key;
120 }
121 uint32_t index = op_->index_;
122 if (op_map->data_.size() <= index) {
123 op_map->data_.resize(index + 1, std::make_pair(TVMRetValue(), 0));
124 }
125 std::pair<TVMRetValue, int>& p = op_map->data_[index];
126 CHECK(p.second != plevel)
127 << "Attribute " << key << " of operator " << this->name
128 << " is already registered with same plevel=" << plevel;
129 CHECK(value.type_code() != kNull)
130 << "Registered packed_func is Null for " << key
131 << " of operator " << this->name;
132 if (p.second < plevel && value.type_code() != kNull) {
133 op_map->data_[index] = std::make_pair(value, plevel);
134 }
135 }
136
137 // Frontend APIs
138 TVM_REGISTER_API("relay.op._ListOpNames")
__anond588af5b0102() 139 .set_body_typed<Array<tvm::Expr>()>([]() {
140 Array<tvm::Expr> ret;
141 for (const std::string& name :
142 dmlc::Registry<OpRegistry>::ListAllNames()) {
143 ret.push_back(tvm::Expr(name));
144 }
145 return ret;
146 });
147
148 TVM_REGISTER_API("relay.op._GetOp").set_body_typed<Op(std::string)>(Op::Get);
149
150 TVM_REGISTER_API("relay.op._OpGetAttr")
__anond588af5b0202(TVMArgs args, TVMRetValue* rv) 151 .set_body([](TVMArgs args, TVMRetValue* rv) {
152 Op op = args[0];
153 std::string attr_name = args[1];
154 auto op_map = Op::GetAttr<TVMRetValue>(attr_name);
155 if (op_map.count(op)) {
156 *rv = op_map[op];
157 }
158 });
159
160 TVM_REGISTER_API("relay.op._OpSetAttr")
__anond588af5b0302(TVMArgs args, TVMRetValue* rv) 161 .set_body([](TVMArgs args, TVMRetValue* rv) {
162 Op op = args[0];
163 std::string attr_name = args[1];
164 runtime::TVMArgValue value = args[2];
165 int plevel = args[3];
166 auto& reg =
167 OpRegistry::Registry()->__REGISTER_OR_GET__(op->name).set_name();
168 reg.set_attr(attr_name, value, plevel);
169 });
170
171 TVM_REGISTER_API("relay.op._OpResetAttr")
__anond588af5b0402(TVMArgs args, TVMRetValue* rv) 172 .set_body([](TVMArgs args, TVMRetValue* rv) {
173 Op op = args[0];
174 std::string attr_name = args[1];
175 auto& reg =
176 OpRegistry::Registry()->__REGISTER_OR_GET__(op->name);
177 reg.reset_attr(attr_name);
178 });
179
180 TVM_REGISTER_API("relay.op._Register")
__anond588af5b0502(TVMArgs args, TVMRetValue* rv) 181 .set_body([](TVMArgs args, TVMRetValue* rv) {
182 std::string op_name = args[0];
183 std::string attr_key = args[1];
184 runtime::TVMArgValue value = args[2];
185 int plevel = args[3];
186 auto& reg =
187 OpRegistry::Registry()->__REGISTER_OR_GET__(op_name).set_name();
188 // enable resgiteration and override of certain properties
189 if (attr_key == "num_inputs" && plevel > 128) {
190 reg.set_num_inputs(value);
191 } else if (attr_key == "attrs_type_key" && plevel > 128) {
192 LOG(FATAL) << "attrs type key no longer supported";
193 } else {
194 // normal attr table override.
195 if (args[2].type_code() == kFuncHandle) {
196 // do an eager copy of the PackedFunc
197 PackedFunc f = args[2];
198 // If we get a function from frontend, avoid deleting it.
199 OpManager::Global()->frontend_funcs.push_back(new PackedFunc(f));
200 reg.set_attr(attr_key, f, plevel);
201 } else {
202 reg.set_attr(attr_key, args[2], plevel);
203 }
204 }
205 });
206
207 // helper to get internal dev function in objectref.
208 struct Op2NodePtr : public ObjectRef {
Gettvm::relay::Op2NodePtr209 static NodePtr<Node> Get(const Op& op) {
210 return GetDataPtr<Node>(op);
211 }
212 };
213
CreateOp(const std::string & name)214 NodePtr<Node> CreateOp(const std::string& name) {
215 // Hack use TVMRetValue as exchange
216 auto op = Op::Get(name);
217 CHECK(op.defined()) << "Cannot find op \'" << name << '\'';
218 return Op2NodePtr::Get(op);
219 }
220
221 TVM_REGISTER_NODE_TYPE(OpNode)
222 .set_creator(CreateOp)
__anond588af5b0602(const Object* n) 223 .set_global_key([](const Object* n) {
224 return static_cast<const OpNode*>(n)->name;
225 });
226
227 TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
__anond588af5b0702(const ObjectRef& ref, IRPrinter* p) 228 .set_dispatch<OpNode>([](const ObjectRef& ref, IRPrinter* p) {
229 auto* node = static_cast<const OpNode*>(ref.get());
230 p->stream << "Op(" << node->name << ")";
231 });
232
233 } // namespace relay
234 } // namespace tvm
235