1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements.  See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership.  The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License.  You may obtain a copy of the License at
9  *
10  *   http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied.  See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
20 /*!
21  * \file src/tvm/relay/op.cc
 * \brief Relay operator registry: operator lookup, per-operator attribute tables and the frontend registration APIs.
23  */
#include <tvm/relay/op.h>
#include <tvm/relay/type.h>
#include <tvm/runtime/module.h>
#include <tvm/runtime/packed_func.h>

#include <atomic>
#include <memory>
#include <mutex>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
31 
namespace dmlc {
// enable registry
// Registers OpRegistry as a dmlc registry entry type. The
// dmlc::Registry<OpRegistry>::Get/Find/ListAllNames calls in this file depend
// on it (presumably the macro defines the registry's backing storage; the
// expansion lives in the dmlc headers — confirm there).
DMLC_REGISTRY_ENABLE(::tvm::relay::OpRegistry);
}  // namespace dmlc
36 
37 namespace tvm {
38 namespace relay {
39 
Registry()40 ::dmlc::Registry<OpRegistry>* OpRegistry::Registry() {
41   return ::dmlc::Registry<OpRegistry>::Get();
42 }
43 
44 // single manager of operator information.
45 struct OpManager {
46   // mutex to avoid registration from multiple threads.
47   std::mutex mutex;
48   // global operator counter
49   std::atomic<int> op_counter{0};
50   // storage of additional attribute table.
51   std::unordered_map<std::string, std::unique_ptr<GenericOpMap>> attr;
52   // frontend functions
53   std::vector<PackedFunc*> frontend_funcs;
54   // get singleton of the op manager
Globaltvm::relay::OpManager55   static OpManager* Global() {
56     static OpManager* inst = new OpManager();
57     return inst;
58   }
59 };
60 
61 // find operator by name
Get(const std::string & name)62 const Op& Op::Get(const std::string& name) {
63   const OpRegistry* reg = dmlc::Registry<OpRegistry>::Find(name);
64   CHECK(reg != nullptr) << "Operator " << name << " is not registered";
65   return reg->op();
66 }
67 
OpRegistry()68 OpRegistry::OpRegistry() {
69   OpManager* mgr = OpManager::Global();
70   NodePtr<OpNode> n = make_node<OpNode>();
71   n->index_ = mgr->op_counter++;
72   op_ = Op(n);
73 }
74 
75 // Get attribute map by key
GetGenericAttr(const std::string & key)76 const GenericOpMap& Op::GetGenericAttr(const std::string& key) {
77   OpManager* mgr = OpManager::Global();
78   std::lock_guard<std::mutex> lock(mgr->mutex);
79   auto it = mgr->attr.find(key);
80   if (it == mgr->attr.end()) {
81     LOG(FATAL) << "Operator attribute \'" << key << "\' is not registered";
82   }
83   return *it->second.get();
84 }
85 
86 // Check if a key is present in the registry.
HasGenericAttr(const std::string & key)87 const bool Op::HasGenericAttr(const std::string& key) {
88   OpManager* mgr = OpManager::Global();
89   std::lock_guard<std::mutex> lock(mgr->mutex);
90   auto it = mgr->attr.find(key);
91   if (it == mgr->attr.end()) {
92     return false;
93   }
94   return true;
95 }
96 
97 // Resets attr of the OpMap.
reset_attr(const std::string & key)98 void OpRegistry::reset_attr(const std::string& key) {
99   OpManager* mgr = OpManager::Global();
100   std::lock_guard<std::mutex> lock(mgr->mutex);
101   std::unique_ptr<GenericOpMap>& op_map = mgr->attr[key];
102   if (op_map == nullptr) {
103     return;
104   }
105   uint32_t index = op_->index_;
106   if (op_map->data_.size() > index) {
107     op_map->data_[index] = std::make_pair(TVMRetValue(), 0);
108   }
109 }
110 
UpdateAttr(const std::string & key,TVMRetValue value,int plevel)111 void OpRegistry::UpdateAttr(const std::string& key,
112                             TVMRetValue value,
113                             int plevel) {
114   OpManager* mgr = OpManager::Global();
115   std::lock_guard<std::mutex> lock(mgr->mutex);
116   std::unique_ptr<GenericOpMap>& op_map = mgr->attr[key];
117   if (op_map == nullptr) {
118     op_map.reset(new GenericOpMap());
119     op_map->attr_name_ = key;
120   }
121   uint32_t index = op_->index_;
122   if (op_map->data_.size() <= index) {
123     op_map->data_.resize(index + 1, std::make_pair(TVMRetValue(), 0));
124   }
125   std::pair<TVMRetValue, int>& p = op_map->data_[index];
126   CHECK(p.second != plevel)
127       << "Attribute " << key << " of operator " << this->name
128       << " is already registered with same plevel=" << plevel;
129   CHECK(value.type_code() != kNull)
130       << "Registered packed_func is Null for " << key
131       << " of operator " << this->name;
132   if (p.second < plevel && value.type_code() != kNull) {
133     op_map->data_[index] = std::make_pair(value, plevel);
134   }
135 }
136 
137 // Frontend APIs
138 TVM_REGISTER_API("relay.op._ListOpNames")
__anon9583e5c50102() 139 .set_body_typed<Array<tvm::Expr>()>([]() {
140     Array<tvm::Expr> ret;
141     for (const std::string& name :
142              dmlc::Registry<OpRegistry>::ListAllNames()) {
143       ret.push_back(tvm::Expr(name));
144     }
145     return ret;
146   });
147 
148 TVM_REGISTER_API("relay.op._GetOp").set_body_typed<Op(std::string)>(Op::Get);
149 
150 TVM_REGISTER_API("relay.op._OpGetAttr")
__anon9583e5c50202(TVMArgs args, TVMRetValue* rv) 151 .set_body([](TVMArgs args, TVMRetValue* rv) {
152     Op op = args[0];
153     std::string attr_name = args[1];
154     auto op_map = Op::GetAttr<TVMRetValue>(attr_name);
155     if (op_map.count(op)) {
156       *rv = op_map[op];
157     }
158   });
159 
160 TVM_REGISTER_API("relay.op._OpSetAttr")
__anon9583e5c50302(TVMArgs args, TVMRetValue* rv) 161 .set_body([](TVMArgs args, TVMRetValue* rv) {
162     Op op = args[0];
163     std::string attr_name = args[1];
164     runtime::TVMArgValue value = args[2];
165     int plevel = args[3];
166     auto& reg =
167         OpRegistry::Registry()->__REGISTER_OR_GET__(op->name).set_name();
168     reg.set_attr(attr_name, value, plevel);
169   });
170 
171 TVM_REGISTER_API("relay.op._OpResetAttr")
__anon9583e5c50402(TVMArgs args, TVMRetValue* rv) 172 .set_body([](TVMArgs args, TVMRetValue* rv) {
173     Op op = args[0];
174     std::string attr_name = args[1];
175     auto& reg =
176         OpRegistry::Registry()->__REGISTER_OR_GET__(op->name);
177     reg.reset_attr(attr_name);
178   });
179 
180 TVM_REGISTER_API("relay.op._Register")
__anon9583e5c50502(TVMArgs args, TVMRetValue* rv) 181 .set_body([](TVMArgs args, TVMRetValue* rv) {
182     std::string op_name = args[0];
183     std::string attr_key = args[1];
184     runtime::TVMArgValue value = args[2];
185     int plevel = args[3];
186     auto& reg =
187         OpRegistry::Registry()->__REGISTER_OR_GET__(op_name).set_name();
188     // enable resgiteration and override of certain properties
189     if (attr_key == "num_inputs" && plevel > 128) {
190       reg.set_num_inputs(value);
191     } else if (attr_key == "attrs_type_key" && plevel > 128) {
192       LOG(FATAL) << "attrs type key no longer supported";
193     } else {
194       // normal attr table override.
195       if (args[2].type_code() == kFuncHandle) {
196         // do an eager copy of the PackedFunc
197         PackedFunc f = args[2];
198         // If we get a function from frontend, avoid deleting it.
199         OpManager::Global()->frontend_funcs.push_back(new PackedFunc(f));
200         reg.set_attr(attr_key, f, plevel);
201       } else {
202         reg.set_attr(attr_key, args[2], plevel);
203       }
204     }
205   });
206 
207 // helper to get internal dev function in objectref.
208 struct Op2NodePtr : public ObjectRef {
Gettvm::relay::Op2NodePtr209   static NodePtr<Node> Get(const Op& op) {
210     return GetDataPtr<Node>(op);
211   }
212 };
213 
CreateOp(const std::string & name)214 NodePtr<Node> CreateOp(const std::string& name) {
215   // Hack use TVMRetValue as exchange
216   auto op = Op::Get(name);
217   CHECK(op.defined()) << "Cannot find op \'" << name << '\'';
218   return Op2NodePtr::Get(op);
219 }
220 
221 TVM_REGISTER_NODE_TYPE(OpNode)
222 .set_creator(CreateOp)
__anon9583e5c50602(const Object* n) 223 .set_global_key([](const Object* n) {
224     return static_cast<const OpNode*>(n)->name;
225   });
226 
227 TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
__anon9583e5c50702(const ObjectRef& ref, IRPrinter* p) 228 .set_dispatch<OpNode>([](const ObjectRef& ref, IRPrinter* p) {
229     auto* node = static_cast<const OpNode*>(ref.get());
230     p->stream << "Op(" << node->name << ")";
231   });
232 
233 }  // namespace relay
234 }  // namespace tvm
235