/**
 * Copyright (c) Glow Contributors. See CONTRIBUTORS file.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "glow/Importer/ProtobufLoader.h"
#include "llvm/Support/CommandLine.h"

#include <algorithm>
#include <string>
#include <vector>

namespace glow {

llvm::cl::OptionCategory loaderOptCat("Model Loader Options");

static llvm::cl::opt<bool> isConstFoldLoaderOps(
    "const-fold-ops",
    llvm::cl::desc(
        "Performs constant folding on ONNX and Caffe2 operators while loading."),
    llvm::cl::init(true), llvm::cl::cat(loaderOptCat));

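/// \returns true if all elements of \p a are equal to the first element,
/// i.e. \p a holds a single repeated value. Empty and single-element arrays
/// are trivially considered constant.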
bool isArrayConstant(llvm::ArrayRef<size_t> a) {
  for (size_t i = 1; i < a.size(); i++)
    if (a[0] != a[i])
      return false;
  return true;
}

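/// Programmatic setter and getter for the -const-fold-ops command-line
/// default defined above; isConstantFoldable() and setupLoader() below
/// consult this flag.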
void setConstantFoldLoaderOpsFlag(bool flag) { isConstFoldLoaderOps = flag; }

bool getConstantFoldLoaderOpsFlag() { return isConstFoldLoaderOps; }

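/// \returns true if an operator of type \p typeName with operands \p inputs
/// can be constant-folded while loading: folding must be enabled via the flag
/// above, \p typeName must not be in the unsupported list (currently only
/// "Constant"), and every input must already be a Constant node.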
bool ProtobufLoader::isConstantFoldable(llvm::ArrayRef<NodeValue> inputs,
                                        std::string typeName) const {
  int numInputs = inputs.size();
  if (!getConstantFoldLoaderOpsFlag()) {
    return false;
  }
  // foldUnsupportedTypes: List of typenames unsupported for folding.
  std::string foldUnsupportedTypes[] = {"Constant"};
  std::string *findType = std::find(std::begin(foldUnsupportedTypes),
                                    std::end(foldUnsupportedTypes), typeName);
  // Early exit if folding is not supported for the current operator.
  if (findType != std::end(foldUnsupportedTypes)) {
    return false;
  }

  // If all the inputs to the operator are Constants, this op can be folded.
  for (int i = 0; i < numInputs; i++) {
    if (inputs[i].getNode()->getKind() != Kinded::Kind::ConstantKind) {
      return false;
    }
  }
  return true;
}

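/// \returns the Placeholder registered under \p name if it exists and is
/// marked static, or nullptr otherwise.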
Placeholder *
ProtobufLoader::getStaticPlaceholderByNameOrNull(llvm::StringRef name) const {
  auto it = nodeValueByName_.find(name);
  if (it == nodeValueByName_.end()) {
    return nullptr;
  }
  auto *res = llvm::dyn_cast<Placeholder>(it->second.getNode());
  return (res && res->isStatic()) ? res : nullptr;
}

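/// \returns the Constant registered under \p name, or nullptr if \p name is
/// unknown or maps to a node that is not a Constant.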
Constant *ProtobufLoader::getConstantByNameOrNull(llvm::StringRef name) const {
  auto it = nodeValueByName_.find(name);
  if (it == nodeValueByName_.end()) {
    return nullptr;
  }
  auto *res = llvm::dyn_cast<Constant>(it->second.getNode());
  return res ? res : nullptr;
}

Expected<Constant *>
ProtobufLoader::getConstantByName(llvm::StringRef name) const {
  auto *ptr = getConstantByNameOrNull(name);
  RETURN_ERR_IF_NOT(
      ptr, strFormat("could not find constant with name %s", name.data()));
  return ptr;
}

bool ProtobufLoader::hasConstantByName(llvm::StringRef name) const {
  return getConstantByNameOrNull(name) != nullptr;
}

Expected<Placeholder *> ProtobufLoader::getSingleOutput() const {
  RETURN_ERR_IF_NOT(outputVarsByName_.size() == 1,
                    "There must be only one output.");
  return outputVarsByName_.begin()->second;
}

Expected<Placeholder *> ProtobufLoader::getSingleInput() const {
  RETURN_ERR_IF_NOT(inputVarsByName_.size() == 1,
                    "There must be only one input.");
  return inputVarsByName_.begin()->second;
}

Expected<Placeholder *>
ProtobufLoader::getOutputByName(llvm::StringRef name) const {
  auto it = outputVarsByName_.find(name);
  RETURN_ERR_IF_NOT(
      it != outputVarsByName_.end(),
      llvm::Twine("No external output Variable was registered with name ", name)
          .str());
  return it->second;
}

Expected<Placeholder *>
ProtobufLoader::getInputByName(llvm::StringRef name) const {
  auto it = inputVarsByName_.find(name);
  RETURN_ERR_IF_NOT(
      it != inputVarsByName_.end(),
      llvm::Twine("No external input Variable was registered with name ", name)
          .str());
  return it->second;
}

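/// Looks up the NodeValue registered under \p name, returning a null
/// NodeValue if none exists. Storage nodes and values from the Function
/// currently being loaded (or from any Function, if \p ignoreSrcFun) are
/// returned directly; values living in a different Function of the Module
/// (pre-partitioned loading) are routed through an intermediate Placeholder
/// created via a SaveNode and cached in intermediatePHsByName_.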
NodeValue
ProtobufLoader::getNodeValueByNameOrNullNodeValue(llvm::StringRef name,
                                                  bool ignoreSrcFun) {
  auto it = nodeValueByName_.find(name);
  if (it == nodeValueByName_.end()) {
    return NodeValue(nullptr);
  }

  // Always return the NV of a storage Node since Storage lives in the Module
  // and is accessible to any Node.
  NodeValue NV = it->second;
  if (llvm::isa<Storage>(NV)) {
    return NV;
  }

  // Check if the current Function G_ we are loading into is the same as the
  // Function of the NV we found; if so then return it.
  Function *srcF = NV.getNode()->getParent();
  if (srcF == G_ || ignoreSrcFun) {
    return NV;
  }

  // Otherwise we must be looking up a NV from a different Function in the
  // Module, so look for an intermediate Placeholder linking the two if it
  // exists, or otherwise create one and remember it.
  assert(partNameToFun_.size() > 0 &&
         "Must be loading a pre-partitioned model.");
  auto itPH = intermediatePHsByName_.find(name);
  Placeholder *intermedPH = nullptr;
  // Create the intermediate PH and SaveNode if it does not yet exist. Note
  // that we store these intermediate PHs separately from nodeValueByName_
  // because we want future users from the same Function as the NV to still
  // use the Node directly through nodeValueByName_.
  if (itPH == intermediatePHsByName_.end()) {
    auto *save = srcF->createSave("tmp_" + NV.getNode()->getName().str(), NV);
    intermedPH = save->getPlaceholder();
    intermediatePHsByName_[name] = intermedPH;
  } else {
    intermedPH = itPH->second;
  }
  return intermedPH->getOutput();
}

Expected<NodeValue> ProtobufLoader::getNodeValueByName(llvm::StringRef name,
                                                       bool ignoreSrcFun) {
  RETURN_ERR_IF_NOT(hasNodeByName(name),
                    llvm::Twine("No node under name ", name).str());
  auto node = getNodeValueByNameOrNullNodeValue(name, ignoreSrcFun);
  RETURN_ERR_IF_NOT(node.getNode(), "Node registered under that name is null");
  return node;
}

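/// Creates a Constant in the Module from \p tensor with the given \p layout
/// and registers it under \p name. If a Placeholder is already registered
/// under \p name it takes precedence and no Constant is created.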
Error ProtobufLoader::createAndRegisterConstant(llvm::StringRef name,
                                                Tensor &&tensor,
                                                const std::string &layout) {
  auto it = nodeValueByName_.find(name);
  if (it != nodeValueByName_.end()) {
    if (llvm::dyn_cast<Placeholder>(it->second.getNode())) {
      // Placeholders take precedence over Constants.
      return Error::success();
    }
  }
  // Note: We do not support training from models loaded from protos, so
  // trainable is always set to false here.
  Constant *node = mod_.createConstant(name, std::move(tensor), layout);
  nodeValueByName_[name] = node->getOutput();
  return Error::success();
}

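/// Erases from the Module every registered Constant that ended up with no
/// users after loading, and drops its entry from nodeValueByName_.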
void ProtobufLoader::deleteUnusedConstants() {
  std::vector<std::string> nodeValuesToRemove;
  for (auto &kv : nodeValueByName_) {
    auto *node = kv.second.getNode();
    if (auto *c = llvm::dyn_cast<Constant>(node)) {
      if (!c->hasUsers()) {
        nodeValuesToRemove.push_back(kv.getKey());
      }
    }
  }

  for (auto &name : nodeValuesToRemove) {
    auto it = nodeValueByName_.find(name);
    auto *c = llvm::dyn_cast<Constant>(it->second.getNode());
    DCHECK(c) << "NodeValue with name " << name
              << " was expected to have been a Constant";
    mod_.eraseConstant(c);
    nodeValueByName_.erase(it);
  }
}

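/// Creates a Placeholder of type \p T in the Module, marks it according to
/// \p isStatic and \p isTrainable, assigns \p layout, and registers it under
/// \p name. Fails if \p name is already registered with this loader or with
/// the Module's storage.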
Expected<Placeholder *>
ProtobufLoader::createAndRegisterPlaceholder(llvm::StringRef name, TypeRef T,
                                             bool isStatic, bool isTrainable,
                                             const std::string &layout) {
  RETURN_ERR_IF_NOT(
      !hasNodeByName(name),
      llvm::Twine("Creating an already existing node ", name).str());
  RETURN_ERR_IF_NOT(!mod_.hasStorageName(name),
                    strFormat("A Placeholder was already registered by name %s",
                              name.data()));

  Placeholder *node = mod_.createPlaceholder(T, name, isTrainable, layout);
  node->setStatic(isStatic);
  nodeValueByName_[name] = node->getOutput();
  return node;
}

bool ProtobufLoader::hasNodeByName(llvm::StringRef name) const {
  return nodeValueByName_.find(name) != nodeValueByName_.end();
}

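/// Construct a loader with no target Function yet (G_ stays null); typically
/// used when loading into a Module, e.g. for pre-partitioned models where
/// per-partition Functions are created during loading.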
ProtobufLoader::ProtobufLoader(llvm::ArrayRef<const char *> tensorNames,
                               llvm::ArrayRef<TypeRef> types, Module &mod,
                               Error *errPtr, bool loadIntoExistingModule)
    : G_(nullptr), mod_(mod), loadIntoExistingModule_(loadIntoExistingModule) {
  setupLoader(tensorNames, types, errPtr);
}

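/// Construct a loader that loads directly into \p F and its parent Module.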
ProtobufLoader::ProtobufLoader(llvm::ArrayRef<const char *> tensorNames,
                               llvm::ArrayRef<TypeRef> types, Function *F,
                               Error *errPtr, bool loadIntoExistingModule)
    : G_(F), mod_(*F->getParent()),
      loadIntoExistingModule_(loadIntoExistingModule) {
  setupLoader(tensorNames, types, errPtr);
}

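/// Shared constructor body: verifies the protobuf library version, seeds
/// constFoldInLoader_ from the command-line flag, and registers an input
/// Placeholder for each (name, type) pair in \p tensorNames / \p types,
/// recording them in inputVarsByName_. Errors are reported through \p errPtr
/// when provided, otherwise they are handled by EXIT_ON_ERR.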
void ProtobufLoader::setupLoader(llvm::ArrayRef<const char *> tensorNames,
                                 llvm::ArrayRef<TypeRef> types, Error *errPtr) {
  // Verify that the version of the library that we linked against is
  // compatible with the version of the headers we compiled against.
  GOOGLE_PROTOBUF_VERIFY_VERSION;

  // If errPtr already contains an error then don't continue with the
  // constructor.
  if (errPtr && *errPtr) {
    return;
  }

  // Use the global flag as default. This may be overridden by instantiations
  // of the loader later on.
  constFoldInLoader_ = getConstantFoldLoaderOpsFlag();

  // Lambda to set up the ProtobufLoader and return any Errors that were
  // raised.
  auto setup = [&]() -> Error {
    RETURN_ERR_IF_NOT(tensorNames.size() == types.size(),
                      "Invalid initialization list");
    for (size_t i = 0, e = tensorNames.size(); i < e; i++) {
      RETURN_ERR_IF_NOT(!hasNodeByName(tensorNames[i]),
                        "Input names have duplicates");
      Placeholder *placeholder;
      ASSIGN_VALUE_OR_RETURN_ERR(
          placeholder, createAndRegisterPlaceholder(tensorNames[i], types[i]));
      inputVarsByName_.try_emplace(tensorNames[i], placeholder);
    }
    return Error::success();
  };

  if (errPtr) {
    *errPtr = setup();
  } else {
    EXIT_ON_ERR(setup());
  }
}

} // namespace glow