1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20 /*!
21 * \file op.cc
22 * \brief Support for operator registry.
23 */
24 #include <nnvm/base.h>
25 #include <nnvm/op.h>
26
27 #include <memory>
28 #include <atomic>
29 #include <mutex>
30 #include <unordered_set>
31
namespace dmlc {
// Enable the dmlc::Registry for nnvm::Op. The Op member functions in this
// file rely on it through dmlc::Registry<Op>::Get(), Find() and List().
DMLC_REGISTRY_ENABLE(nnvm::Op);
}  // namespace dmlc
36
37 namespace nnvm {
38
39 // single manager of operator information.
40 struct OpManager {
41 // mutex to avoid registration from multiple threads.
42 // recursive is needed for trigger(which calls UpdateAttrMap)
43 std::recursive_mutex mutex;
44 // global operator counter
45 std::atomic<int> op_counter{0};
46 // storage of additional attribute table.
47 std::unordered_map<std::string, std::unique_ptr<any> > attr;
48 // storage of existing triggers
49 std::unordered_map<std::string, std::vector<std::function<void(Op*)> > > tmap;
50 // group of each operator.
51 std::vector<std::unordered_set<std::string> > op_group;
52 // get singleton of the
Globalnnvm::OpManager53 static OpManager* Global() {
54 static OpManager inst;
55 return &inst;
56 }
57 };
58
59 // constructor
Op()60 Op::Op() {
61 OpManager* mgr = OpManager::Global();
62 index_ = mgr->op_counter++;
63 }
64
add_alias(const std::string & alias)65 Op& Op::add_alias(const std::string& alias) { // NOLINT(*)
66 dmlc::Registry<Op>::Get()->AddAlias(this->name, alias);
67 return *this;
68 }
69
70 // find operator by name
Get(const std::string & name)71 const Op* Op::Get(const std::string& name) {
72 const Op* op = dmlc::Registry<Op>::Find(name);
73 CHECK(op != nullptr)
74 << "Operator " << name << " is not registered";
75 return op;
76 }
77
78 // Get attribute map by key
GetAttrMap(const std::string & key)79 const any* Op::GetAttrMap(const std::string& key) {
80 auto& dict = OpManager::Global()->attr;
81 auto it = dict.find(key);
82 if (it != dict.end()) {
83 return it->second.get();
84 } else {
85 return nullptr;
86 }
87 }
88
89 // update attribute map
UpdateAttrMap(const std::string & key,std::function<void (any *)> updater)90 void Op::UpdateAttrMap(const std::string& key,
91 std::function<void(any*)> updater) {
92 OpManager* mgr = OpManager::Global();
93 std::lock_guard<std::recursive_mutex>(mgr->mutex);
94 std::unique_ptr<any>& value = mgr->attr[key];
95 if (value.get() == nullptr) value.reset(new any());
96 if (updater != nullptr) updater(value.get());
97 }
98
AddGroupTrigger(const std::string & group_name,std::function<void (Op *)> trigger)99 void Op::AddGroupTrigger(const std::string& group_name,
100 std::function<void(Op*)> trigger) {
101 OpManager* mgr = OpManager::Global();
102 std::lock_guard<std::recursive_mutex>(mgr->mutex);
103 auto& tvec = mgr->tmap[group_name];
104 tvec.push_back(trigger);
105 auto& op_group = mgr->op_group;
106 for (const Op* op : dmlc::Registry<Op>::List()) {
107 if (op->index_ < op_group.size() &&
108 op_group[op->index_].count(group_name) != 0) {
109 trigger((Op*)op); // NOLINT(*)
110 }
111 }
112 }
113
include(const std::string & group_name)114 Op& Op::include(const std::string& group_name) {
115 OpManager* mgr = OpManager::Global();
116 std::lock_guard<std::recursive_mutex>(mgr->mutex);
117 auto it = mgr->tmap.find(group_name);
118 if (it != mgr->tmap.end()) {
119 for (auto& trigger : it->second) {
120 trigger(this);
121 }
122 }
123 auto& op_group = mgr->op_group;
124 if (index_ >= op_group.size()) {
125 op_group.resize(index_ + 1);
126 }
127 op_group[index_].insert(group_name);
128 return *this;
129 }
130
131 } // namespace nnvm
132