1 /**
2 * Copyright (c) Glow Contributors. See CONTRIBUTORS file.
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "glow/Importer/ProtobufLoader.h"
18 #include "llvm/Support/CommandLine.h"
19 #include <string>
20
21 namespace glow {
22
// Command-line option category that groups the model loader flags.
llvm::cl::OptionCategory loaderOptCat("Model Loader Options");

// `-const-fold-ops`: whether operators are constant folded while a model is
// being loaded. Defaults to true; it can also be changed programmatically via
// setConstantFoldLoaderOpsFlag() and read via getConstantFoldLoaderOpsFlag().
static llvm::cl::opt<bool> isConstFoldLoaderOps(
    "const-fold-ops",
    llvm::cl::desc(
        "Performs constant folding on ONNX and Caffe Operators while loading."),
    llvm::cl::init(true), llvm::cl::cat(loaderOptCat));
30
isArrayConstant(llvm::ArrayRef<size_t> a)31 bool isArrayConstant(llvm::ArrayRef<size_t> a) {
32 for (size_t i = 1; i < a.size(); i++)
33 if (a[0] != a[i])
34 return false;
35 return true;
36 }
37
setConstantFoldLoaderOpsFlag(bool flag)38 void setConstantFoldLoaderOpsFlag(bool flag) { isConstFoldLoaderOps = flag; }
39
getConstantFoldLoaderOpsFlag()40 bool getConstantFoldLoaderOpsFlag() { return isConstFoldLoaderOps; }
41
isConstantFoldable(llvm::ArrayRef<NodeValue> inputs,std::string typeName) const42 bool ProtobufLoader::isConstantFoldable(llvm::ArrayRef<NodeValue> inputs,
43 std::string typeName) const {
44 int numInputs = inputs.size();
45 if (!getConstantFoldLoaderOpsFlag()) {
46 return false;
47 }
48 // foldUnsupportedTypes: List of typenames unsupported for folding.
49 std::string foldUnsupportedTypes[] = {"Constant"};
50 std::string *findType = std::find(std::begin(foldUnsupportedTypes),
51 std::end(foldUnsupportedTypes), typeName);
52 // Early exit if folding is not supported for current operator.
53 if (findType != std::end(foldUnsupportedTypes)) {
54 return false;
55 }
56
57 // If all the inputs to the operator are constant this op can be folded.
58 for (int i = 0; i < numInputs; i++) {
59 if (inputs[i].getNode()->getKind() != Kinded::Kind::ConstantKind) {
60 return false;
61 }
62 }
63 return true;
64 }
65
66 Placeholder *
getStaticPlaceholderByNameOrNull(llvm::StringRef name) const67 ProtobufLoader::getStaticPlaceholderByNameOrNull(llvm::StringRef name) const {
68 auto it = nodeValueByName_.find(name);
69 if (it == nodeValueByName_.end()) {
70 return nullptr;
71 }
72 auto *res = llvm::dyn_cast<Placeholder>(it->second.getNode());
73 return (res && res->isStatic()) ? res : nullptr;
74 }
75
getConstantByNameOrNull(llvm::StringRef name) const76 Constant *ProtobufLoader::getConstantByNameOrNull(llvm::StringRef name) const {
77 auto it = nodeValueByName_.find(name);
78 if (it == nodeValueByName_.end()) {
79 return nullptr;
80 }
81 auto *res = llvm::dyn_cast<Constant>(it->second.getNode());
82 return res ? res : nullptr;
83 }
84
85 Expected<Constant *>
getConstantByName(llvm::StringRef name) const86 ProtobufLoader::getConstantByName(llvm::StringRef name) const {
87 auto *ptr = getConstantByNameOrNull(name);
88 RETURN_ERR_IF_NOT(
89 ptr, strFormat("could not find constant with name %s", name.data()));
90 return ptr;
91 }
92
hasConstantByName(llvm::StringRef name) const93 bool ProtobufLoader::hasConstantByName(llvm::StringRef name) const {
94 return getConstantByNameOrNull(name) != nullptr;
95 }
96
getSingleOutput() const97 Expected<Placeholder *> ProtobufLoader::getSingleOutput() const {
98 RETURN_ERR_IF_NOT(outputVarsByName_.size() == 1,
99 "There must be only one output.");
100 return outputVarsByName_.begin()->second;
101 }
102
getSingleInput() const103 Expected<Placeholder *> ProtobufLoader::getSingleInput() const {
104 RETURN_ERR_IF_NOT(inputVarsByName_.size() == 1,
105 "There must be only one input.");
106 return inputVarsByName_.begin()->second;
107 }
108
109 Expected<Placeholder *>
getOutputByName(llvm::StringRef name) const110 ProtobufLoader::getOutputByName(llvm::StringRef name) const {
111 auto it = outputVarsByName_.find(name);
112 RETURN_ERR_IF_NOT(
113 it != outputVarsByName_.end(),
114 llvm::Twine("No external output Variable was registered with name ", name)
115 .str());
116 return it->second;
117 }
118
119 Expected<Placeholder *>
getInputByName(llvm::StringRef name) const120 ProtobufLoader::getInputByName(llvm::StringRef name) const {
121 auto it = inputVarsByName_.find(name);
122 RETURN_ERR_IF_NOT(
123 it != inputVarsByName_.end(),
124 llvm::Twine("No external input Variable was registered with name ", name)
125 .str());
126 return it->second;
127 }
128
129 NodeValue
getNodeValueByNameOrNullNodeValue(llvm::StringRef name,bool ignoreSrcFun)130 ProtobufLoader::getNodeValueByNameOrNullNodeValue(llvm::StringRef name,
131 bool ignoreSrcFun) {
132 auto it = nodeValueByName_.find(name);
133 if (it == nodeValueByName_.end()) {
134 return NodeValue(nullptr);
135 }
136
137 // Always return the NV of a storage Node since Storage lives in the Module
138 // and is accessible to any Node.
139 NodeValue NV = it->second;
140 if (llvm::isa<Storage>(NV)) {
141 return NV;
142 }
143
144 // Check if the current Function G_ we are loading into is the same as the
145 // Function of the NV we found; if so then return it.
146 Function *srcF = NV.getNode()->getParent();
147 if (srcF == G_ || ignoreSrcFun) {
148 return NV;
149 }
150
151 // Otherwise we must be looking up a NV from a different Function in the
152 // Module, so look for an intermediate Placeholder linking the two if it
153 // exists, or otherwise create one and remember it.
154 assert(partNameToFun_.size() > 0 &&
155 "Must be loading a pre-partitioned model.");
156 auto itPH = intermediatePHsByName_.find(name);
157 Placeholder *intermedPH = nullptr;
158 // Create the intermediate PH and SaveNode if it does not yet exist. Note that
159 // we store these intermediate PHs separately from nodeValueByName_ because we
160 // want future users from the same Function as the NV to still use the Node
161 // directly through nodeValueByName_.
162 if (itPH == intermediatePHsByName_.end()) {
163 auto *save = srcF->createSave("tmp_" + NV.getNode()->getName().str(), NV);
164 intermedPH = save->getPlaceholder();
165 intermediatePHsByName_[name] = intermedPH;
166 } else {
167 intermedPH = itPH->second;
168 }
169 return intermedPH->getOutput();
170 }
171
getNodeValueByName(llvm::StringRef name,bool ignoreSrcFun)172 Expected<NodeValue> ProtobufLoader::getNodeValueByName(llvm::StringRef name,
173 bool ignoreSrcFun) {
174 RETURN_ERR_IF_NOT(hasNodeByName(name),
175 llvm::Twine("No node under name ", name).str());
176 auto node = getNodeValueByNameOrNullNodeValue(name, ignoreSrcFun);
177 RETURN_ERR_IF_NOT(node.getNode(), "Null is under that name??");
178 return node;
179 }
180
createAndRegisterConstant(llvm::StringRef name,Tensor && tensor,const std::string & layout)181 Error ProtobufLoader::createAndRegisterConstant(llvm::StringRef name,
182 Tensor &&tensor,
183 const std::string &layout) {
184 auto it = nodeValueByName_.find(name);
185 if (it != nodeValueByName_.end()) {
186 if (llvm::dyn_cast<Placeholder>(it->second.getNode())) {
187 // Placeholders take precedence over Constants.
188 return Error::success();
189 }
190 }
191 // Note: We do not support training from models loaded from protos, so
192 // trainable is always set to false here.
193 Constant *node = mod_.createConstant(name, std::move(tensor), layout);
194 nodeValueByName_[name] = node->getOutput();
195 return Error::success();
196 }
197
deleteUnusedConstants()198 void ProtobufLoader::deleteUnusedConstants() {
199 std::vector<std::string> nodeValuesToRemove;
200 for (auto &kv : nodeValueByName_) {
201 auto *node = kv.second.getNode();
202 if (auto *c = llvm::dyn_cast<Constant>(node)) {
203 if (!c->hasUsers()) {
204 nodeValuesToRemove.push_back(kv.getKey());
205 }
206 }
207 }
208
209 for (auto &name : nodeValuesToRemove) {
210 auto it = nodeValueByName_.find(name);
211 auto *c = llvm::dyn_cast<Constant>(it->second.getNode());
212 DCHECK(c) << "NodeValue with name " << name
213 << " was expected to have been a Constant";
214 mod_.eraseConstant(c);
215 nodeValueByName_.erase(it);
216 }
217 }
218
219 Expected<Placeholder *>
createAndRegisterPlaceholder(llvm::StringRef name,TypeRef T,bool isStatic,bool isTrainable,const std::string & layout)220 ProtobufLoader::createAndRegisterPlaceholder(llvm::StringRef name, TypeRef T,
221 bool isStatic, bool isTrainable,
222 const std::string &layout) {
223 RETURN_ERR_IF_NOT(
224 !hasNodeByName(name),
225 llvm::Twine("Creating an already existing node ", name).str());
226 RETURN_ERR_IF_NOT(!mod_.hasStorageName(name),
227 strFormat("A Placeholder was already registered by name %s",
228 name.data()));
229
230 Placeholder *node = mod_.createPlaceholder(T, name, isTrainable, layout);
231 node->setStatic(isStatic);
232 nodeValueByName_[name] = node->getOutput();
233 return node;
234 }
235
hasNodeByName(llvm::StringRef name) const236 bool ProtobufLoader::hasNodeByName(llvm::StringRef name) const {
237 return nodeValueByName_.find(name) != nodeValueByName_.end();
238 }
239
ProtobufLoader(llvm::ArrayRef<const char * > tensorNames,llvm::ArrayRef<TypeRef> types,Module & mod,Error * errPtr,bool loadIntoExistingModule)240 ProtobufLoader::ProtobufLoader(llvm::ArrayRef<const char *> tensorNames,
241 llvm::ArrayRef<TypeRef> types, Module &mod,
242 Error *errPtr, bool loadIntoExistingModule)
243 : G_(nullptr), mod_(mod), loadIntoExistingModule_(loadIntoExistingModule) {
244 setupLoader(tensorNames, types, errPtr);
245 }
246
ProtobufLoader(llvm::ArrayRef<const char * > tensorNames,llvm::ArrayRef<TypeRef> types,Function * F,Error * errPtr,bool loadIntoExistingModule)247 ProtobufLoader::ProtobufLoader(llvm::ArrayRef<const char *> tensorNames,
248 llvm::ArrayRef<TypeRef> types, Function *F,
249 Error *errPtr, bool loadIntoExistingModule)
250 : G_(F), mod_(*F->getParent()),
251 loadIntoExistingModule_(loadIntoExistingModule) {
252 setupLoader(tensorNames, types, errPtr);
253 }
254
setupLoader(llvm::ArrayRef<const char * > tensorNames,llvm::ArrayRef<TypeRef> types,Error * errPtr)255 void ProtobufLoader::setupLoader(llvm::ArrayRef<const char *> tensorNames,
256 llvm::ArrayRef<TypeRef> types, Error *errPtr) {
257 // Verify that the version of the library that we linked against is
258 // compatible with the version of the headers we compiled against.
259 GOOGLE_PROTOBUF_VERIFY_VERSION;
260
261 // if errPtr already contains an error then don't continue with constructor
262 if (errPtr && *errPtr) {
263 return;
264 }
265
266 // Use the global flag as default. This may be overridden by instantiations of
267 // the loader later on.
268 constFoldInLoader_ = getConstantFoldLoaderOpsFlag();
269
270 // Lambda to setup the ProtobufLoader and return any Errors that were
271 // raised.
272 auto setup = [&]() -> Error {
273 RETURN_ERR_IF_NOT(tensorNames.size() == types.size(),
274 "Invalid initialization list");
275 for (size_t i = 0, e = tensorNames.size(); i < e; i++) {
276 RETURN_ERR_IF_NOT(!hasNodeByName(tensorNames[i]),
277 "Input names have duplicate");
278 Placeholder *placeholder;
279 ASSIGN_VALUE_OR_RETURN_ERR(
280 placeholder, createAndRegisterPlaceholder(tensorNames[i], types[i]));
281 inputVarsByName_.try_emplace(tensorNames[i], placeholder);
282 }
283 return Error::success();
284 };
285
286 if (errPtr) {
287 *errPtr = setup();
288 } else {
289 EXIT_ON_ERR(setup());
290 }
291 }
292
293 }; // namespace glow
294