1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements.  See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership.  The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License.  You may obtain a copy of the License at
9  *
10  *   http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied.  See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
20 /*!
21  * \file op.cc
22  * \brief Support for operator registry.
23  */
24 #include <nnvm/base.h>
25 #include <nnvm/op.h>
26 
27 #include <memory>
28 #include <atomic>
29 #include <mutex>
30 #include <unordered_set>
31 
namespace dmlc {
// Enable the dmlc registry for nnvm::Op. The functions in namespace nnvm below
// reach it through dmlc::Registry<Op>::Get / Find / List.
// NOTE(review): presumably this macro supplies the registry's singleton
// definition for this entry type; confirm against dmlc/registry.h.
DMLC_REGISTRY_ENABLE(nnvm::Op);
}  // namespace dmlc
36 
37 namespace nnvm {
38 
39 // single manager of operator information.
40 struct OpManager {
41   // mutex to avoid registration from multiple threads.
42   // recursive is needed for trigger(which calls UpdateAttrMap)
43   std::recursive_mutex mutex;
44   // global operator counter
45   std::atomic<int> op_counter{0};
46   // storage of additional attribute table.
47   std::unordered_map<std::string, std::unique_ptr<any> > attr;
48   // storage of existing triggers
49   std::unordered_map<std::string, std::vector<std::function<void(Op*)>  > > tmap;
50   // group of each operator.
51   std::vector<std::unordered_set<std::string> > op_group;
52   // get singleton of the
Globalnnvm::OpManager53   static OpManager* Global() {
54     static OpManager inst;
55     return &inst;
56   }
57 };
58 
59 // constructor
Op()60 Op::Op() {
61   OpManager* mgr = OpManager::Global();
62   index_ = mgr->op_counter++;
63 }
64 
add_alias(const std::string & alias)65 Op& Op::add_alias(const std::string& alias) {  // NOLINT(*)
66   dmlc::Registry<Op>::Get()->AddAlias(this->name, alias);
67   return *this;
68 }
69 
70 // find operator by name
Get(const std::string & name)71 const Op* Op::Get(const std::string& name) {
72   const Op* op = dmlc::Registry<Op>::Find(name);
73   CHECK(op != nullptr)
74       << "Operator " << name << " is not registered";
75   return op;
76 }
77 
78 // Get attribute map by key
GetAttrMap(const std::string & key)79 const any* Op::GetAttrMap(const std::string& key) {
80   auto& dict =  OpManager::Global()->attr;
81   auto it = dict.find(key);
82   if (it != dict.end()) {
83     return it->second.get();
84   } else {
85     return nullptr;
86   }
87 }
88 
89 // update attribute map
UpdateAttrMap(const std::string & key,std::function<void (any *)> updater)90 void Op::UpdateAttrMap(const std::string& key,
91                        std::function<void(any*)> updater) {
92   OpManager* mgr = OpManager::Global();
93   std::lock_guard<std::recursive_mutex>(mgr->mutex);
94   std::unique_ptr<any>& value = mgr->attr[key];
95   if (value.get() == nullptr) value.reset(new any());
96   if (updater != nullptr) updater(value.get());
97 }
98 
AddGroupTrigger(const std::string & group_name,std::function<void (Op *)> trigger)99 void Op::AddGroupTrigger(const std::string& group_name,
100                          std::function<void(Op*)> trigger) {
101   OpManager* mgr = OpManager::Global();
102   std::lock_guard<std::recursive_mutex>(mgr->mutex);
103   auto& tvec = mgr->tmap[group_name];
104   tvec.push_back(trigger);
105   auto& op_group = mgr->op_group;
106   for (const Op* op : dmlc::Registry<Op>::List()) {
107     if (op->index_ < op_group.size() &&
108         op_group[op->index_].count(group_name) != 0) {
109       trigger((Op*)op);  // NOLINT(*)
110     }
111   }
112 }
113 
include(const std::string & group_name)114 Op& Op::include(const std::string& group_name) {
115   OpManager* mgr = OpManager::Global();
116   std::lock_guard<std::recursive_mutex>(mgr->mutex);
117   auto it = mgr->tmap.find(group_name);
118   if (it != mgr->tmap.end()) {
119     for (auto& trigger : it->second) {
120       trigger(this);
121     }
122   }
123   auto& op_group = mgr->op_group;
124   if (index_ >= op_group.size()) {
125     op_group.resize(index_ + 1);
126   }
127   op_group[index_].insert(group_name);
128   return *this;
129 }
130 
131 }  // namespace nnvm
132