/**
 * Copyright (c) Glow Contributors. See CONTRIBUTORS file.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "glow/Graph/Graph.h"
#include "glow/Backend/Backend.h"
#include "glow/Graph/Nodes.h"
#include "glow/Graph/PlaceholderBindings.h"
#include "glow/Graph/TensorLayout.h"
#include "glow/Graph/VerifierHelper.h"
#include "glow/Quantization/Base/Base.h"
#include "glow/Support/Support.h"

#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/SmallString.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/FileSystem.h"
#include "llvm/Support/Format.h"
#include "llvm/Support/raw_ostream.h"

#ifdef WIN32
#include <corecrt_math_defines.h>
#endif
#include <float.h>
#include <fstream>
#include <unordered_set>

using namespace glow;
using llvm::cast;
using llvm::dyn_cast;
using llvm::isa;

namespace {
/// A helper function to log the deletion of constant/placeholder \p s of a
/// module into the log context of the given functions \p functions.
/// Note: We do not log the deletion of a constant/placeholder in the function
/// that uses or creates it, because constants/placeholders have no function
/// parent (and we cannot rely on a user's function either, since the users may
/// already have been removed). Instead, the event is logged in a Module-level
/// log context and copied over to all of the module's functions.
void logStorageDeletion(std::list<Function *> functions, Storage *s) {
  for (auto *F : functions) {
    F->getLogContext()->logNodeDeletion(*s);
  }
  if (functions.size() > 0) {
    auto *F = *(functions.begin());
    F->getLogContext()->logNodeDeletion(*s, /* logIntoModule */ true);
  }
}

/// A helper function to log the creation of constant/placeholder \p s of a
/// module into the log context of given functions \p functions.
/// Same note as for logStorageDeletion().
void logStorageCreation(std::list<Function *> functions, Storage *s) {
  for (auto *F : functions) {
    F->getLogContext()->logNodeCreation(*s);
  }
  if (functions.size() > 0) {
    auto *F = *(functions.begin());
    F->getLogContext()->logNodeCreation(*s, /* logIntoModule */ true);
  }
}
} // namespace

/// Merge shape \p shape into \p mergeShape, following multidirectional
/// broadcasting rules.
static void mergeMultidirectionalBroadcastHelper(std::vector<dim_t> &mergeShape,
                                                 llvm::ArrayRef<dim_t> shape) {
  size_t shift = mergeShape.size() - shape.size();
  for (size_t i = 0, e = shape.size(); i < e; i++) {
    if (shape[i] == 1) {
      // Just leave mergeShape[i] as it is.
      continue;
    }

    assert(
        ((shape[i] == mergeShape[shift + i]) || (mergeShape[shift + i] == 1)) &&
        "Incompatible dimension for the broadcast");
    mergeShape[shift + i] = shape[i];
  }
}

/// Utility function which computes the resulting shape in case of
/// multidirectional broadcasting.
static std::vector<dim_t>
computeMultidirectionalBroadcastHelper(llvm::ArrayRef<dim_t> shape0,
                                       llvm::ArrayRef<dim_t> shape1) {
  size_t numDims0 = shape0.size();
  size_t numDims1 = shape1.size();
  size_t newNumDims = std::max(numDims0, numDims1);
  std::vector<dim_t> reshapeDims(newNumDims, 1);

  mergeMultidirectionalBroadcastHelper(reshapeDims, shape0);
  mergeMultidirectionalBroadcastHelper(reshapeDims, shape1);

  return reshapeDims;
}
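// For example, computeMultidirectionalBroadcastHelper({3, 4, 1}, {4, 5})
// right-aligns the two shapes and stretches size-1 dimensions, yielding
// {3, 4, 5}.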

std::vector<NodeValue>
Function::broadcastInputs(int axis, const llvm::ArrayRef<NodeValue> inputs) {
  dim_t numInputs = inputs.size();

  if (axis > -1) {
    assert(numInputs == 2 &&
           "If axis is specified (not -1), unidirectional broadcast is used "
           "and the number of inputs must be 2.");
    return {inputs[0],
            createBroadcast("broadcast_" + inputs[1].getNode()->getName().str(),
                            inputs[1], inputs[0].dims(), axis)};
  }

  assert(numInputs >= 2 && "Invalid input passed in to broadcastInputs.");

  std::vector<dim_t> targetDim = computeMultidirectionalBroadcastHelper(
      inputs[0].dims(), inputs[1].dims());

  for (size_t i = 2; i < numInputs; ++i) {
    targetDim =
        computeMultidirectionalBroadcastHelper(targetDim, inputs[i].dims());
  }

  std::vector<NodeValue> out(numInputs);
  for (size_t i = 0; i < numInputs; ++i) {
    NodeValue n = inputs[i];
    auto dims = n.dims();
    if (dims != llvm::ArrayRef<dim_t>(targetDim)) {
      unsigned axis = targetDim.size() - dims.size();
      out[i] = createBroadcast("broadcast_" + n.getNode()->getName().str(), n,
                               targetDim, axis);
    } else {
      out[i] = inputs[i];
    }
  }
  return out;
}
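// For example, with axis == -1, inputs of shapes {5, 3, 1} and {3, 4} are both
// broadcast to the common shape {5, 3, 4}; with an explicit axis, only the
// second of exactly two inputs is broadcast to the shape of the first.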

bool Module::hasFunction(llvm::StringRef name) { return getFunction(name); }

void Module::clearFunctions() {
  for (auto *F : functions_) {
    F->clear();
  }
}

void Function::clear() {
  nodes_.clear();
  uniqueNodeNames_.clear();
}

Function *Module::getFunction(llvm::StringRef name) {
  for (auto *F : functions_) {
    if (F->getName() == name) {
      return F;
    }
  }
  return nullptr;
}

Function *Module::createFunction(llvm::StringRef name) {
  assert(!hasFunction(name) && "A function with this name already exists");
  Function *F = new Function(this, name);
  functions_.push_back(F);
  return F;
}

void Module::strip() {
  for (auto it = constants_.begin(), e = constants_.end(); it != e; it++) {
    Constant *v = *it;
    v->clearPayload();
  }
}

void Module::clear() {
  for (auto it = constants_.begin(), e = constants_.end(); it != e; it++) {
    Constant *v = *it;
    logStorageDeletion(functions_, v);
    delete v;
  }

  constants_.clear();

  for (auto it = placeholders_.begin(), e = placeholders_.end(); it != e;
       it++) {
    Placeholder *p = *it;
    logStorageDeletion(functions_, p);
    delete p;
  }

  eraseFunctions();

  placeholders_.clear();
}

Module::~Module() { clear(); }
bool Module::verify() const {
  bool isValid = true;
  for (auto *F : functions_) {
    isValid &= F->verify();
  }
  // Check that all types used by constants or placeholders belong to the
  // module.
  auto &types = getTypes();
  for (const auto *PH : getPlaceholders()) {
    bool foundType =
        std::find(types.begin(), types.end(), *PH->getType()) != types.end();
    isValid &=
        expectCompareTrue("Every type used by placeholders should be part of "
                          "the graph",
                          foundType, true, PH);
  }
  for (const auto *C : getConstants()) {
    bool foundType =
        std::find(types.begin(), types.end(), *C->getType()) != types.end();
    isValid &=
        expectCompareTrue("Every type used by constants should be part of "
                          "the graph",
                          foundType, true, C);
  }
  return isValid;
}

void Module::dump() const {
  llvm::outs() << "Module structure:\n";
  for (auto *C : getConstants()) {
    llvm::outs() << C->getDebugDesc() << "\n";
  }

  for (auto *P : getPlaceholders()) {
    llvm::outs() << P->getDebugDesc() << "\n";
  }

  for (auto *F : functions_) {
    llvm::outs() << "Function:" << F->getName() << "\n";
  }
}

std::string Module::toString() const {
  std::string storage;
  llvm::raw_string_ostream os(storage);
  dump(os);
  return os.str();
}

/// Creates a std::set copy of \p unsorted, sorted based on name of each
/// element, and \returns it.
template <class T>
static std::set<T *, SortNamed> getNamedSorted(const std::list<T *> &unsorted) {
  return std::set<T *, SortNamed>(unsorted.begin(), unsorted.end());
}

void Module::dump(llvm::raw_ostream &os) const {
  os << "Module structure:\n";
  for (auto *C : getNamedSorted(constants_)) {
    os << C->getDebugDesc() << "\n";
  }
  for (auto *P : getNamedSorted(placeholders_)) {
    os << P->getDebugDesc() << "\n";
  }
  for (auto *F : getNamedSorted(functions_)) {
    os << "Function : " << F->getName() << "\n";
  }
}

/// A helper class for visiting and generating the dotty graph file.
class AbstractDottyPrinter {
protected:
  // List of generated vertices.
  std::vector<std::string> vertices_{};
  // List of generated edges.
  std::unordered_set<std::string> edges_{};
  // Map node addresses to unique numbers.
  using VertexNumberMap = std::unordered_map<void *, unsigned>;
  VertexNumberMap vertex_numbers{};

  /// Dumps label for an input/output row, given port names.
  /// E.g. {"LHS", "RHS"} will produce {<LHS>LHS|<RHS>RHS}
  void dumpLabelForRow(llvm::ArrayRef<std::string> names, std::ostream &os) {
    os << "{";
    for (size_t i = 0; i < names.size(); i++) {
      if (i) {
        os << "|";
      }
      os << "<" << names[i] << ">" << names[i];
    }
    os << "}";
  }

  void dumpLabel(Node *N, std::ostream &os) {
    os << "{";
    if (N->getNumInputs()) {
      std::vector<std::string> names(N->getNumInputs());
      for (size_t i = 0; i < names.size(); i++) {
        names[i] = N->getInputName(i);
      }
      dumpLabelForRow(names, os);
      os << "|";
    }
    os << "{" << escapeDottyString(N->getDebugDesc()) << "}";
    if (N->getNumResults()) {
      os << "|";
      std::vector<std::string> names(N->getNumResults());
      for (size_t i = 0; i < names.size(); i++) {
        names[i] = N->getOutputName(i).str();
      }
      dumpLabelForRow(names, os);
    }
    os << "}";
  }
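
  // For example, an Add node is rendered as a record label of the form
  // "{{<LHS>LHS|<RHS>RHS}|{Add ...}|{<Result>Result}}": an input row, the
  // node description, then an output row.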

  void dumpNode(Node *N, bool uniqueNames) {
    if (!N) {
      return;
    }
    std::ostringstream os;
    // Print a node descriptor that looks like this:
    if (uniqueNames) {
      // vNNNN [ shape = "record" label = "{...}" ];
      os << uniqueVertexName(N) << "[\n";
    } else {
      // <name> [ shape = "record" label = "{...}" ];
      os << N->getName().str() << "[\n";
    }
    os << "\tlabel = \"";
    dumpLabel(N, os);
    os << "\"\n";
    os << "\tshape = \"record\"\n";
    os << "\tstyle=\"filled,rounded\"\n";

    // Pick a color based on the node kind.
    unsigned colorIdx = llvm::hash_value(llvm::StringRef(N->getKindName()));
    auto nodeColor = getDotFileNodeColor(colorIdx);

    if (isa<Constant>(N)) {
      os << "\tfillcolor=Snow3 color=DeepSkyBlue4\n";
    } else {
      os << "\tfillcolor=" << nodeColor << "\n";
    }
    os << "penwidth = 2];\n";

    vertices_.push_back(os.str());
  }

  void dumpEdgeStyle(const Node *N, size_t i, Node *to, std::ostream &os) {
    if (N->isOverwrittenNthInput(i)) {
      os << " [dir=\"both\"]";
    }
  }

  std::string uniqueVertexName(void *N) {
    VertexNumberMap::iterator i;
    bool inserted;
    std::tie(i, inserted) = vertex_numbers.insert(std::make_pair(N, 0u));
    if (inserted) {
      i->second = vertex_numbers.size() - 1;
    }

    std::string buffer;
    llvm::raw_string_ostream stream(buffer);
    stream << llvm::format("v%04u", i->second);
    return stream.str();
  }

public:
  void dumpAll(std::ostream &os) {
    CHECK(os) << "Failed to create file to dump Graph";

    os << "digraph DAG {\n\trankdir=TB;\n";

    // Dump vertices:
    for (auto &v : vertices_) {
      os << v << "\n";
    }

    // Dump edges:
    for (auto &e : edges_) {
      os << e << ";\n";
    }

    os << "}";
  }
};

class ModuleDottyPrinter : public AbstractDottyPrinter {
  /// Dump the Function as a vertex, then iterate through the constants used in
  /// the function and create the corresponding edges.
  void visitFunction(Function *F) {
    std::ostringstream os;
    // Print a Function descriptor that looks like this:
    // vNNNN [ label = "{...}" ];
    os << uniqueVertexName(F) << "[\n"
       << "\tlabel = \"Function\\l"
       << "name : " << F->getName().str() << "\\l"
       << "node count : " << F->getNodes().size() << "\"\n"
       << "\tshape = box\n"
       << "\tfillcolor=gray89, style=\"filled,rounded\"\n"
       << "\t\n"
       << "];\n";
    vertices_.push_back(os.str());

    for (auto &N : F->getNodes()) {
      for (size_t i = 0; i < N.getNumInputs(); i++) {
        Node *to = N.getNthInput(i).getNode();
        size_t resNo = N.getNthInput(i).getResNo();

        if (!isa<Constant>(to))
          continue;

        std::ostringstream edge;
        edge << uniqueVertexName(to) << ":" << to->getOutputName(resNo).str()
             << " -> " << uniqueVertexName(F);
        dumpEdgeStyle(&N, i, to, edge);
        edges_.insert(edge.str());
      }
    }
  }

public:
  void visitModule(Module *M) {
    for (auto N : M->getConstants()) {
      dumpNode(N, true);
    }

    for (auto F : M->getFunctions()) {
      visitFunction(F);
    }
  }
};

// TODO: consider refactoring boilerplate code to new trait: DottyPrintable<ADP>
void Module::dumpDAG() {
  llvm::SmallString<64> dotPath;
  llvm::sys::fs::createTemporaryFile("dotty_graph_dump", "dot", dotPath);
  dumpDAG(dotPath);
}

void Module::dumpDAG(llvm::StringRef dotFilename) {
  llvm::outs() << "Writing dotty graph for Module to: " << dotFilename << '\n';

  ModuleDottyPrinter DP;

  DP.visitModule(this);

  std::ofstream myfile;
  myfile.open(dotFilename);
  DP.dumpAll(myfile);
  myfile.close();
}

void Module::dumpDAG(const char *dotFilename) {
  dumpDAG(llvm::StringRef(dotFilename));
}

void Module::eraseFunctions() {
  while (!functions_.empty()) {
    eraseFunction(*functions_.begin());
  }
}

void Module::eraseFunction(Function *F) {
  auto it = std::find(functions_.begin(), functions_.end(), F);
  assert(it != functions_.end() && "Function is not part of a module");
  functions_.erase(it);
  delete F;
}

uint64_t Module::getConstantsSize() {
  uint64_t size = 0;
  for (auto *constant : constants_) {
    size += constant->getPayload().getSizeInBytes();
  }
  return size;
}

Function::~Function() {
  // Delete all of the nodes.
  for (auto it = nodes_.begin(), e = nodes_.end(); it != e;) {
    auto cur = it++;
    eraseNode(&*cur);
  }
}

TypeRef Module::uniqueType(ElemKind elemTy, llvm::ArrayRef<dim_t> dims) {
  return uniqueType(Type(elemTy, dims));
}

TypeRef Module::uniqueType(ElemKind elemTy, llvm::ArrayRef<dim_t> dims,
                           float scale, int32_t offset) {
  return uniqueType(Type(elemTy, dims, scale, offset));
}

TypeRef Module::uniqueTypeWithNewShape(TypeRef T, llvm::ArrayRef<dim_t> dims) {
  return uniqueType(Type::newShape(*T, dims));
}

TypeRef Module::uniqueTypeWithNewShape(TypeRef T, llvm::ArrayRef<dim_t> dims,
                                       llvm::ArrayRef<dim_t> alignments) {
  return uniqueType(Type::newShape(*T, dims, alignments));
}

TypeRef Module::uniqueTypeWithNewShape(TypeRef T, TypeRef shapeType) {
  return uniqueType(Type::newShape(*T, shapeType));
}

TypeRef Module::uniqueType(const Type &T) {
  for (auto &tp : types_) {
    if (T.isEqual(tp)) {
      return &tp;
    }
  }

  return &*types_.insert(types_.begin(), T);
}

TypeRef Module::getVoidTy() { return uniqueType(Type()); }

/// \returns a ShapeVector of rank axes.size() less than the input \p dims,
/// where the provided \p axes dimensions are removed from the shape.
static ShapeVector getNewShapeWithoutAxes(llvm::ArrayRef<dim_t> dims,
                                          llvm::ArrayRef<unsigned_t> axes) {
  assert(axes.size() <= dims.size() &&
         "Cannot remove more dimensions than exist.");
  ShapeVector newDims(dims.begin(), dims.end());
  ShapeVector shapeAxes(axes.begin(), axes.end());

  // Sort so that looping erase below doesn't fail.
  std::sort(shapeAxes.rbegin(), shapeAxes.rend());

  for (const auto &axis : shapeAxes) {
    assert(axis <= dims.size() &&
           "Axis to remove must fit inside dimensions of the provided dims.");
    newDims.erase(newDims.begin() + axis);
  }
  return newDims;
}
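// For example, getNewShapeWithoutAxes({2, 3, 4, 5}, {1, 3}) removes the
// dimensions at indices 1 and 3 and returns {2, 4}.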

//===----------------------------------------------------------------------===//
//                       Node builders
//===----------------------------------------------------------------------===//

Placeholder *Module::createPlaceholder(TypeRef T, llvm::StringRef name,
                                       bool isTrainable,
                                       const std::string &layout) {
  auto FT = uniqueType(*T);
  auto *ph = new Placeholder(name, FT, isTrainable, layout);
  ph->setName(uniqueName(ph->getName(), usedNodeNames_, usedStorageNames_,
                         originalNames_));
  placeholders_.push_back(ph);
  logStorageCreation(functions_, ph);
  return ph;
}

Placeholder *Module::createPlaceholder(ElemKind T, llvm::ArrayRef<dim_t> dims,
                                       llvm::StringRef name, bool isTrainable,
                                       const std::string &layout) {
  auto FT = uniqueType(T, dims);
  return createPlaceholder(FT, name, isTrainable, layout);
}

Placeholder *Module::createPlaceholder(ElemKind T, llvm::ArrayRef<dim_t> dims,
                                       float scale, int32_t offset,
                                       llvm::StringRef name, bool isTrainable,
                                       const std::string &layout) {
  auto FT = uniqueType(T, dims, scale, offset);
  return createPlaceholder(FT, name, isTrainable, layout);
}

Constant *Module::createConstant(TypeRef T, llvm::StringRef name,
                                 const std::string &layout) {
  auto FT = uniqueType(*T);
  return addConstant(new Constant(name, FT, layout));
}

Constant *Module::createConstant(ElemKind T, llvm::ArrayRef<dim_t> dims,
                                 llvm::StringRef name,
                                 const std::string &layout) {
  auto FT = uniqueType(T, dims);
  return createConstant(FT, name, layout);
}

Constant *Module::createConstant(ElemKind T, llvm::ArrayRef<dim_t> dims,
                                 float scale, int32_t offset,
                                 llvm::StringRef name,
                                 const std::string &layout) {
  auto FT = uniqueType(T, dims, scale, offset);
  return createConstant(FT, name, layout);
}

Constant *Module::createConstant(llvm::StringRef name, const Tensor &tensor,
                                 const std::string &layout) {
  auto *V = createConstant(&tensor.getType(), name, layout);
  V->assign(&tensor);
  return V;
}

Constant *Module::createConstant(llvm::StringRef name, Tensor &&tensor,
                                 const std::string &layout) {
  return addConstant(new Constant(name, std::move(tensor), layout));
}

std::string Module::getPrefix(llvm::StringRef name) {
  std::string prefix = name;
  size_t delim = name.rfind("__");
  if (delim != std::string::npos &&
      std::all_of(name.begin() + (delim + 2), name.end(),
                  [](unsigned char c) { return ::isdigit(c); })) {
    prefix = prefix.substr(0, delim);
  }
  return prefix;
}

llvm::StringRef Module::uniqueName(llvm::StringRef name,
                                   const llvm::StringSet<> &stringTable,
                                   llvm::StringSet<> &updateTable,
                                   const llvm::StringSet<> &originalNames) {
  std::string legalName = legalizeName(name);
  if (stringTable.find(legalName) == stringTable.end()) {
    auto it = updateTable.insert(legalName);
    if (it.second) {
      return it.first->first();
    }
  }
  // Retain the trailing "__[0-9]+" if it is in the original name.
  std::string prefix = (originalNames.find(legalName) == originalNames.end())
                           ? Module::getPrefix(legalName)
                           : legalName;
  for (unsigned i = 1; i < 10000; i++) {
    auto suffix = std::to_string(i);
    std::string fullName = prefix + "__" + suffix;
    if (stringTable.find(fullName) != stringTable.end()) {
      continue;
    }

    auto it = updateTable.insert(fullName);
    if (it.second) {
      return it.first->first();
    }
  }
  llvm_unreachable("Unable to find a unique name.");
}
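// For example, if "relu" is already taken, uniqueName("relu", ...) returns the
// first free "relu__N" (typically "relu__1"). A name listed in \p
// originalNames, such as "conv__7", keeps its numeric suffix and becomes
// "conv__7__1" on a collision rather than having the "__7" suffix stripped.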

Constant *Module::addConstant(Constant *V) {
  V->setName(uniqueName(V->getName(), usedNodeNames_, usedStorageNames_,
                        originalNames_));
  // Replace the Constant's output type with the equivalent unique type for
  // this Module to maintain the invariant that each type in the Module is
  // unique.
  V->setType(Constant::ResultIndices::OutputIdx, uniqueType(*V->getType()));
  constants_.push_back(V);
  logStorageCreation(functions_, V);
  return V;
}

/// Check if the 'pads' array has the right size.
static void assertPadsSize(NodeValue input, llvm::ArrayRef<int> pads) {
  assert((pads.size() == 2 * input.dims().size()) &&
         "the pads array must contain 2 values per dimension");
}

PadNode *Function::createPad(llvm::StringRef name, NodeValue input,
                             TypeRef outTy, unsigned_t mode,
                             llvm::ArrayRef<int> pads, float value) {
  assertPadsSize(input, pads);
  auto OT = getParent()->uniqueType(*outTy);
  return addNode(new PadNode(name, OT, input, mode, pads, value));
}

/// Check the kernel size for Conv/Pooling ops.
static void checkKernelSize(ShapeNHWC idim, llvm::ArrayRef<unsigned_t> kernels,
                            llvm::ArrayRef<unsigned_t> pads) {
  PaddingTLBR pdim(pads);
  (void)pdim;
  ShapeHW kdim(kernels);
  (void)kdim;
  assert((idim.w + pdim.left + pdim.right) >= kdim.width &&
         (idim.h + pdim.top + pdim.bottom) >= kdim.height &&
         "Kernel size is too large");
}

/// Check the kernel size for 3D Conv/Pooling ops.
static void check3DKernelSize(ShapeNTHWC idim,
                              llvm::ArrayRef<unsigned_t> kernels,
                              llvm::ArrayRef<unsigned_t> pads) {
  PaddingNFTBLR pdim(pads);
  (void)pdim;
  ShapeTHW kdim(kernels);
  (void)kdim;
  assert((idim.w + pdim.left + pdim.right) >= kdim.width &&
         (idim.h + pdim.top + pdim.bottom) >= kdim.height &&
         (idim.t + pdim.near + pdim.far) >= kdim.temporal_frames &&
         "Kernel size is too large");
}

/// Check that the dimensions that are passed in when the ConvTranspose is
/// constructed are correct.
static void assertConvTransposeDims(NodeValue input, NodeValue filter,
                                    NodeValue bias,
                                    llvm::ArrayRef<unsigned_t> kernels,
                                    llvm::ArrayRef<unsigned_t> strides,
                                    llvm::ArrayRef<unsigned_t> pads,
                                    unsigned_t group) {
  ShapeNHWC idim = ShapeNHWC(input.dims());
  (void)idim;
  ShapeHW kdim(kernels);
  (void)kdim;
  assert(idim.c % group == 0 && "channels number must be divisible by groups");

  // NOTE: here the N in NHWC is abnormal because it is the number of filters
  // (and therefore the number of output channels of the conv) and not the
  // batch size. The rest of the dimensions are representative of the input
  // dimensions to the convolution.
  ShapeNHWC filterDims(filter.dims());
  (void)filterDims;

  assert(filterDims.n % group == 0 && filterDims.h == kdim.height &&
         filterDims.w == kdim.width && filterDims.c == idim.c / group &&
         "Invalid filter dims");

  assert(bias.getType()->size() == filterDims.n && "Invalid bias size");
}

/// Check that the dimensions that are passed in when the convolution is
/// constructed are correct.
static void assertConvDims(NodeValue input, NodeValue filter, NodeValue bias,
                           llvm::ArrayRef<unsigned_t> kernels,
                           llvm::ArrayRef<unsigned_t> strides,
                           llvm::ArrayRef<unsigned_t> pads, unsigned_t group) {
  ShapeNHWC idim = ShapeNHWC(input.dims());
  ShapeHW kdim(kernels);
  (void)kdim;
  checkKernelSize(idim, kernels, pads);
  assert(idim.c % group == 0 && "channels number must be divisible by groups");

  // NOTE: here the N in NHWC is abnormal because it is the number of filters
  // (and therefore the number of output channels of the conv) and not the
  // batch size. The rest of the dimensions are representative of the input
  // dimensions to the convolution.
  ShapeNHWC filterDims(filter.dims());
  (void)filterDims;

  assert(filterDims.n % group == 0 && filterDims.h == kdim.height &&
         filterDims.w == kdim.width && filterDims.c == idim.c / group &&
         "Invalid filter dims");

  assert(bias.getType()->size() == filterDims.n && "Invalid bias size");
}
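// For example, an NHWC input of shape {8, 56, 56, 16} with kernels {3, 3} and
// group 2 expects a filter of shape {numFilters, 3, 3, 8}, where numFilters is
// divisible by the group, and a bias with numFilters elements.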

/// Check that the dimensions that are passed in when the 3D convolution is
/// constructed are correct.
static void assertConv3DDims(NodeValue input, NodeValue filter, NodeValue bias,
                             llvm::ArrayRef<unsigned_t> kernels,
                             llvm::ArrayRef<unsigned_t> strides,
                             llvm::ArrayRef<unsigned_t> pads,
                             unsigned_t group) {
  ShapeNTHWC idim(input.dims());
  ShapeTHW kdim(kernels);
  (void)kdim;
  check3DKernelSize(idim, kernels, pads);
  assert(idim.c % group == 0 && "channels number must be divisible by groups");

  // NOTE: here the N in NTHWC is abnormal because it is the number of filters
  // (and therefore the number of output channels of the 3d conv) and not the
  // batch size. The rest of the dimensions are representative of the input
  // dimensions to the convolution.
  ShapeNTHWC filterDims(filter.dims());
  (void)filterDims;

  assert(filterDims.n % group == 0 && filterDims.h == kdim.height &&
         filterDims.w == kdim.width && filterDims.t == kdim.temporal_frames &&
         filterDims.c == idim.c / group && "Invalid filter dims");

  assert(bias.getType()->size() == filterDims.n && "Invalid bias size");
}

ConvolutionNode *Function::createConv(
    llvm::StringRef name, NodeValue input, NodeValue filter, NodeValue bias,
    TypeRef outTy, llvm::ArrayRef<unsigned_t> kernels,
    llvm::ArrayRef<unsigned_t> strides, llvm::ArrayRef<unsigned_t> pads,
    unsigned_t group, unsigned_t dilation, ConvolutionLayout layout) {
  assertConvDims(input, filter, bias, kernels, strides, pads, group);
  auto OT = getParent()->uniqueType(*outTy);

  // If the input is quantized but the bias is not then auto-quantize the
  // bias.
  if (input.getType()->isQuantizedType()) {
    auto biasType = bias.getElementType();
    if (biasType == ElemKind::Int32QTy || biasType == ElemKind::Int8QTy) {
      // Nothing to do
    } else if (biasType == ElemKind::FloatTy) {
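      // The accumulator of a quantized convolution works on input * filter
      // products, so the int32 bias is given scale inputScale * filterScale
      // and offset 0 to match.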
      auto biasTy = getParent()->uniqueType(
          glow::ElemKind::Int32QTy, bias.dims(),
          input.getType()->getScale() * filter.getType()->getScale(),
          /* offset */ 0);
      bias = createQuantize("quantized_bias", bias, biasTy);
    } else {
      LOG(DFATAL)
          << "Unsupported element type for bias of quantized convolution: "
          << Type::getElementName(biasType).str();
    }
  }

  return addNode(new ConvolutionNode(name, OT, input, filter, bias, kernels,
                                     strides, pads, group, dilation, layout,
                                     FusedActivation::NONE));
}

ConvolutionNode *Function::createConv(llvm::StringRef name, NodeValue input,
                                      NodeValue filter, NodeValue bias,
                                      TypeRef outTy, unsigned_t kernel,
                                      unsigned_t stride, unsigned_t pad,
                                      unsigned_t group, unsigned_t dilation,
                                      ConvolutionLayout layout) {
  llvm::SmallVector<unsigned_t, 4> pads = {pad, pad, pad, pad};
  llvm::SmallVector<unsigned_t, 2> strides = {stride, stride};
  llvm::SmallVector<unsigned_t, 2> kernels = {kernel, kernel};
  return createConv(name, input, filter, bias, outTy, kernels, strides, pads,
                    group, dilation, layout);
}

Convolution3DNode *Function::createConv3D(llvm::StringRef name, NodeValue input,
                                          NodeValue filter, NodeValue bias,
                                          TypeRef outTy,
                                          llvm::ArrayRef<unsigned_t> kernels,
                                          llvm::ArrayRef<unsigned_t> strides,
                                          llvm::ArrayRef<unsigned_t> pads,
                                          unsigned_t group) {
  assertConv3DDims(input, filter, bias, kernels, strides, pads, group);
  auto OT = getParent()->uniqueType(*outTy);

  // If the input is quantized but the bias is not then auto-quantize the
  // bias.
  if (input.getType()->isQuantizedType()) {
    auto biasType = bias.getElementType();
    if (biasType == ElemKind::Int32QTy || biasType == ElemKind::Int8QTy ||
        biasType == ElemKind::Int16QTy) {
      // Nothing to do
    } else if (biasType == ElemKind::FloatTy) {
      auto biasTy = getParent()->uniqueType(
          glow::ElemKind::Int32QTy, bias.dims(),
          input.getType()->getScale() * filter.getType()->getScale(),
          /* offset */ 0);
      bias = createQuantize("quantized_bias", bias, biasTy);
    } else {
      LOG(DFATAL)
          << "Unsupported element type for bias of quantized convolution: "
          << Type::getElementName(biasType).str();
    }
  }
  return addNode(new Convolution3DNode(name, OT, input, filter, bias, kernels,
                                       strides, pads, group));
}

Convolution3DNode *Function::createConv3D(llvm::StringRef name, NodeValue input,
                                          NodeValue filter, NodeValue bias,
                                          TypeRef outTy, unsigned_t kernel,
                                          unsigned_t stride, unsigned_t pad,
                                          unsigned_t group) {
  llvm::SmallVector<unsigned_t, 6> pads = {pad, pad, pad, pad, pad, pad};
  llvm::SmallVector<unsigned_t, 3> strides = {stride, stride, stride};
  llvm::SmallVector<unsigned_t, 3> kernels = {kernel, kernel, kernel};
  return createConv3D(name, input, filter, bias, outTy, kernels, strides, pads,
                      group);
}

ConvTransposeNode *Function::createConvTranspose(
    llvm::StringRef name, NodeValue input, NodeValue filter, NodeValue bias,
    TypeRef outTy, llvm::ArrayRef<unsigned_t> kernels,
    llvm::ArrayRef<unsigned_t> strides, llvm::ArrayRef<unsigned_t> pads,
    unsigned_t group, unsigned_t dilation) {
  assertConvTransposeDims(input, filter, bias, kernels, strides, pads, group);
  auto OT = getParent()->uniqueType(*outTy);
  return addNode(new ConvTransposeNode(name, OT, input, filter, bias, kernels,
                                       strides, pads, group, dilation));
}

ConvTransposeNode *Function::createConvTranspose(
    llvm::StringRef name, NodeValue input, NodeValue filter, NodeValue bias,
    TypeRef outTy, unsigned_t kernel, unsigned_t stride, unsigned_t pad,
    unsigned_t group, unsigned_t dilation) {
  llvm::SmallVector<unsigned_t, 4> pads = {pad, pad, pad, pad};
  llvm::SmallVector<unsigned_t, 2> strides = {stride, stride};
  llvm::SmallVector<unsigned_t, 2> kernels = {kernel, kernel};
  return createConvTranspose(name, input, filter, bias, outTy, kernels, strides,
                             pads, group, dilation);
}

MaxPoolNode *Function::createMaxPool(llvm::StringRef name, NodeValue input,
                                     llvm::ArrayRef<unsigned_t> kernels,
                                     llvm::ArrayRef<unsigned_t> strides,
                                     llvm::ArrayRef<unsigned_t> pads,
                                     ElemKind elemTyAMT,
                                     ConvolutionLayout layout) {
  ShapeNHWC idim = ShapeNHWC(input.dims());
  checkKernelSize(idim, kernels, pads);

  auto outSz =
      calculateConvPoolOutputDims(idim.h, idim.w, kernels, strides, pads);
  auto OT = getParent()->uniqueTypeWithNewShape(
      input.getType(), {idim.n, outSz.first, outSz.second, idim.c});
  auto AMT = getParent()->uniqueType(
      elemTyAMT, {idim.n, outSz.first, outSz.second, idim.c});

  return addNode(
      new MaxPoolNode(name, OT, AMT, input, kernels, strides, pads, layout));
}

MaxPoolNode *Function::createMaxPool(llvm::StringRef name, NodeValue input,
                                     unsigned_t kernel, unsigned_t stride,
                                     unsigned_t pad, ElemKind elemTyAMT,
                                     ConvolutionLayout layout) {
  llvm::SmallVector<unsigned_t, 4> pads = {pad, pad, pad, pad};
  llvm::SmallVector<unsigned_t, 2> strides = {stride, stride};
  llvm::SmallVector<unsigned_t, 2> kernels = {kernel, kernel};
  return createMaxPool(name, input, kernels, strides, pads, elemTyAMT, layout);
}

AvgPoolNode *Function::createAvgPool(llvm::StringRef name, NodeValue input,
                                     llvm::ArrayRef<unsigned_t> kernels,
                                     llvm::ArrayRef<unsigned_t> strides,
                                     llvm::ArrayRef<unsigned_t> pads,
                                     ConvolutionLayout layout) {
  if (!is3DData(layout)) {

    ShapeNHWC idim = ShapeNHWC(input.dims());
    checkKernelSize(idim, kernels, pads);

    auto outSz =
        calculateConvPoolOutputDims(idim.h, idim.w, kernels, strides, pads);
    auto OT = getParent()->uniqueTypeWithNewShape(
        input.getType(), {idim.n, outSz.first, outSz.second, idim.c});
    return addNode(
        new AvgPoolNode(name, OT, input, kernels, strides, pads, layout));

  } else {
    ShapeNTHWC idim = ShapeNTHWC(input.dims());
    check3DKernelSize(idim, kernels, pads);

    auto outSz = calculate3DConvPoolOutputDims(idim.t, idim.h, idim.w, kernels,
                                               strides, pads);
    auto OT = getParent()->uniqueTypeWithNewShape(
        input.getType(),
        {idim.n, outSz.temporal_frames, outSz.height, outSz.width, idim.c});
    return addNode(
        new AvgPoolNode(name, OT, input, kernels, strides, pads, layout));
  }
}

AvgPoolNode *Function::createAvgPool(llvm::StringRef name, NodeValue input,
                                     TypeRef outTy,
                                     llvm::ArrayRef<unsigned_t> kernels,
                                     llvm::ArrayRef<unsigned_t> strides,
                                     llvm::ArrayRef<unsigned_t> pads,
                                     ConvolutionLayout layout) {
  if (!is3DData(layout)) {

    ShapeNHWC idim = ShapeNHWC(input.dims());
    ShapeHW kdim(kernels);
    (void)kdim;
    checkKernelSize(idim, kernels, pads);
    return addNode(
        new AvgPoolNode(name, outTy, input, kernels, strides, pads, layout));

  } else {

    ShapeNTHWC idim = ShapeNTHWC(input.dims());
    ShapeTHW kdim(kernels);
    (void)kdim;
    check3DKernelSize(idim, kernels, pads);
    return addNode(
        new AvgPoolNode(name, outTy, input, kernels, strides, pads, layout));
  }
}

AvgPoolNode *Function::createAvgPool(llvm::StringRef name, NodeValue input,
                                     unsigned_t kernel, unsigned_t stride,
                                     unsigned_t pad, ConvolutionLayout layout) {
  if (!is3DData(layout)) {

    llvm::SmallVector<unsigned_t, 4> pads = {pad, pad, pad, pad};
    llvm::SmallVector<unsigned_t, 2> strides = {stride, stride};
    llvm::SmallVector<unsigned_t, 2> kernels = {kernel, kernel};
    return createAvgPool(name, input, kernels, strides, pads, layout);

  } else {

    llvm::SmallVector<unsigned_t, 6> pads = {pad, pad, pad, pad, pad, pad};
    llvm::SmallVector<unsigned_t, 3> strides = {stride, stride, stride};
    llvm::SmallVector<unsigned_t, 3> kernels = {kernel, kernel, kernel};
    return createAvgPool(name, input, kernels, strides, pads, layout);
  }
}

AdaptiveAvgPoolNode *Function::createAdaptiveAvgPool(llvm::StringRef name,
                                                     NodeValue input,
                                                     TypeRef outTy) {
  return addNode(new AdaptiveAvgPoolNode(name, outTy, input));
}

GemmNode *Function::createGemm(llvm::StringRef name, NodeValue A, NodeValue B,
                               NodeValue C, float alpha, float beta,
                               bool transposeA, bool transposeB) {
  std::vector<dim_t> outDims(2);
  outDims[0] = transposeA ? A.dims()[1] : A.dims()[0];
  outDims[1] = transposeB ? B.dims()[0] : B.dims()[1];
  TypeRef outTy = getParent()->uniqueTypeWithNewShape(A.getType(), outDims);
  return createGemm(name, outTy, A, B, C, alpha, beta, transposeA, transposeB);
}
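// The Gemm node follows the standard GEMM convention,
// alpha * op(A) * op(B) + beta * C, where op() transposes its argument when
// the corresponding flag is set; the output shape derived above is therefore
// {rows of op(A), columns of op(B)}.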

GemmNode *Function::createGemm(llvm::StringRef name, TypeRef outTy, NodeValue A,
                               NodeValue B, NodeValue C, float alpha,
                               float beta, bool transposeA, bool transposeB) {
  // If C operand is not given then we create a 1D splat with 0.
  if (!C.getNode()) {
    TypeRef splatTy =
        getParent()->uniqueTypeWithNewShape(outTy, {outTy->dims()[1]});
    C = createSplat(name.str() + ".SplatC", splatTy, 0.0f);
  }
  // If C operand is a 2D constant we check if it is a broadcasted version of
  // a 1D tensor. If yes then we slice and reshape the C operand to 1D.
  if (auto *constC = llvm::dyn_cast<Constant>(C.getNode())) {
    if ((constC->dims().size() == 2) && (constC->getPayload().isTiled(0))) {
      // Slice and reshape to 1D.
      dim_t lengthC = constC->dims()[1];
      C = createSlice(name.str() + ".SliceC", C, {0, 0}, {1, lengthC});
      C = createReshape(name.str() + ".ReshapeC", C, {lengthC});
    }
  }
  TypeRef OT = getParent()->uniqueType(*outTy);
  return addNode(
      new GemmNode(name, OT, A, B, C, alpha, beta, transposeA, transposeB));
}

FullyConnectedNode *Function::createFullyConnected(llvm::StringRef name,
                                                   NodeValue input, Storage *W,
                                                   Storage *B,
                                                   unsigned_t axis) {
  TypeRef T = input.getType();
  TypeRef OT = getParent()->uniqueTypeWithNewShape(
      T, {input.dims()[0], B->getType()->dims()[0]});

  return createFullyConnected(name, input, W, B, OT, axis);
}

FullyConnectedNode *Function::createFullyConnected(llvm::StringRef name,
                                                   NodeValue input, NodeValue W,
                                                   NodeValue B,
                                                   unsigned_t axis) {
  TypeRef T = input.getType();
  TypeRef OT =
      getParent()->uniqueTypeWithNewShape(T, {input.dims()[0], B.dims()[0]});

  return createFullyConnected(name, input, W, B, OT, axis);
}

FullyConnectedNode *Function::createFullyConnected(llvm::StringRef name,
                                                   NodeValue input, NodeValue W,
                                                   NodeValue B, TypeRef outTy,
                                                   unsigned_t axis) {
  assert(outTy->dims().size() == 2 && "Invalid number of dimensions");
  assert(outTy->dims()[0] == input.dims()[0] && "Invalid dimensions");

  // FC always uses 2D input; flatten if necessary.
  if (input.dims().size() != 2) {
    input = createFlatten(name.str() + ".reshape2D", input, axis);
  }

  TypeRef OT = getParent()->uniqueType(*outTy);
  return addNode(new FullyConnectedNode(name, OT, input, W, B));
}

RowwiseQuantizedFullyConnectedNode *
Function::createRowwiseQuantizedFullyConnected(llvm::StringRef name,
                                               NodeValue input, Constant *W,
                                               Constant *scales,
                                               Constant *offsets, NodeValue B,
                                               TypeRef outTy) {
  return addNode(new RowwiseQuantizedFullyConnectedNode(name, outTy, input, W,
                                                        scales, offsets, B));
}

RowwiseQuantizedFullyConnectedNode *
Function::createRowwiseQuantizedFullyConnected(llvm::StringRef name,
                                               NodeValue input, Constant *W,
                                               NodeValue B, TypeRef outTy,
                                               quantization::Schema schema,
                                               bool transposeWeight) {
  // Since W is constant, quantize it at compile time. The quantized data ends
  // up in qWeights, the per-row scales in scales, and the per-row offsets in
  // offsets.
  Constant *weights = llvm::cast<Constant>(W);
  dim_t numRows =
      transposeWeight ? W->getType()->dims()[1] : W->getType()->dims()[0];
  dim_t numCols =
      transposeWeight ? W->getType()->dims()[0] : W->getType()->dims()[1];

  // So far, creating a storage with Int8QTy/Int16QTy implies quantized data
  // for which a scale and offset must be provided. For rowwise quantization
  // the scales and offsets are stored in separate vectors, so we pass a dummy
  // scale and offset here.
  auto *qWeights = getParent()->createConstant(
      ElemKind::Int8QTy, {numRows, numCols}, 0.0, 0, "weights.rwqfc");
  auto *scales =
      getParent()->createConstant(ElemKind::FloatTy, {numRows}, "scales.rwqfc");
  auto *offsets = getParent()->createConstant(ElemKind::Int32ITy, {numRows},
                                              "offsets.rwqfc");

  Tensor wt;
  if (transposeWeight) {
    // This happens when the RowwiseQuantizedFullyConnected node is converted
    // from a quantized FullyConnected node in Glow's quantization procedure.
    // In FC the weights are stored transposed (i.e. I * W + B), whereas
    // RowwiseQuantizedFullyConnected stores the weights as-is (i.e.
    // I * W(T) + B), so transpose them here.
    weights->getPayloadMutable().transpose(&wt, {1, 0});
  } else {
    wt.assign(&(weights->getPayload()));
  }
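
  // For example, with transposeWeight an FC weight of shape
  // {inFeatures, outFeatures} becomes a rowwise weight of shape
  // {outFeatures, inFeatures}, i.e. {numRows, numCols}, with one scale and
  // offset per row.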

  // Note: Using int32_t offset here as that is what RWQ-FC expects.
  quantization::tensorRowwiseQuantization<float, int32_t, int8_t>(
      wt, qWeights->getPayloadMutable(), scales->getPayloadMutable(),
      offsets->getPayloadMutable(), schema);

  return addNode(new RowwiseQuantizedFullyConnectedNode(
      name, outTy, input, qWeights, scales, offsets, B));
}

ReluNode *Function::createRELU(llvm::StringRef name, NodeValue input,
                               TypeRef outTy) {
  return addNode(new ReluNode(name, outTy, input));
}

ReluNode *Function::createRELU(llvm::StringRef name, NodeValue input) {
  return addNode(new ReluNode(name, input.getType(), input));
}

Node *Function::createGELU(llvm::StringRef name, NodeValue input) {
  auto outTy = input.getType();

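  // This builds the tanh-based GELU approximation:
  //   0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
  // where M_2_SQRTPI * M_SQRT1_2 == sqrt(2/pi).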
  Node *alphaSplat =
      createSplat(name.str() + ".alpha", outTy, M_2_SQRTPI * M_SQRT1_2);
  Node *splat = createSplat(name.str() + ".splat", outTy, 0.044715);
  Node *splatHalf = createSplat(name.str() + ".splatHalf", outTy, 0.5);
  Node *splat1 = createSplat(name.str() + ".splat1", outTy, 1.0);
  Node *splat3 = createSplat(name.str() + ".splat3", outTy, 3.0);

  // pow(x, 3)
  Node *pow = createPow(name.str() + ".pow", input, splat3);

  // pow(x, 3) * 0.044715
  Node *mul = createMul(name.str() + ".mul", pow, splat);

  // x + pow(x, 3) * 0.044715
  Node *add = createAdd(name.str() + ".add", input, mul);

  // (x + pow(x, 3) * 0.044715) * alpha
  Node *mul2 = createMul(name.str() + ".mul2", add, alphaSplat);

  // tanh((x + pow(x, 3) * 0.044715) * alpha)
  Node *tanh = createTanh(name.str() + ".tanh", mul2);

  // tanh((x + pow(x, 3) * 0.044715) * alpha) + 1
  Node *add2 = createAdd(name.str() + ".add2", tanh, splat1);

  // (tanh((x + pow(x, 3) * 0.044715) * alpha) + 1) * 0.5
  Node *mul3 = createMul(name.str() + ".mul3", splatHalf, add2);

  // (tanh((x + pow(x, 3) * 0.044715) * alpha) + 1) * 0.5 * x
  return createMul(name.str() + ".mul4", mul3, input);
}
1183 
createPRELU(llvm::StringRef name,NodeValue input,NodeValue slope,TypeRef outTy)1184 PReluNode *Function::createPRELU(llvm::StringRef name, NodeValue input,
1185                                  NodeValue slope, TypeRef outTy) {
1186   return addNode(new PReluNode(name, outTy, input, slope));
1187 }
1188 
createPRELU(llvm::StringRef name,NodeValue input,NodeValue slope)1189 PReluNode *Function::createPRELU(llvm::StringRef name, NodeValue input,
1190                                  NodeValue slope) {
1191   return addNode(new PReluNode(name, input.getType(), input, slope));
1192 }
1193 
createSigmoid(llvm::StringRef name,TypeRef outTy,NodeValue input)1194 SigmoidNode *Function::createSigmoid(llvm::StringRef name, TypeRef outTy,
1195                                      NodeValue input) {
1196   return addNode(new SigmoidNode(name, outTy, input));
1197 }
1198 
createSigmoid(llvm::StringRef name,NodeValue input)1199 SigmoidNode *Function::createSigmoid(llvm::StringRef name, NodeValue input) {
1200   return createSigmoid(name, input.getType(), input);
1201 }
1202 
createSwish(llvm::StringRef name,NodeValue input)1203 SwishNode *Function::createSwish(llvm::StringRef name, NodeValue input) {
1204   return addNode(new SwishNode(name, input.getType(), input));
1205 }
1206 
createTanh(llvm::StringRef name,TypeRef outTy,NodeValue input)1207 TanhNode *Function::createTanh(llvm::StringRef name, TypeRef outTy,
1208                                NodeValue input) {
1209   return addNode(new TanhNode(name, outTy, input));
1210 }
1211 
createTanh(llvm::StringRef name,NodeValue input)1212 TanhNode *Function::createTanh(llvm::StringRef name, NodeValue input) {
1213   return createTanh(name, input.getType(), input);
1214 }
1215 
createSoftMax(llvm::StringRef name,NodeValue input,NodeValue selected,TypeRef outTy,float beta)1216 SoftMaxNode *Function::createSoftMax(llvm::StringRef name, NodeValue input,
1217                                      NodeValue selected, TypeRef outTy,
1218                                      float beta) {
1219   // Create input multiplier with beta.
1220   if (beta != 1.0) {
1221     auto *splat = createSplat(name, input.getType(), beta);
1222     input = createMul(name, input, splat);
1223   }
1224   // By default, pick the input type.
1225   if (!outTy) {
1226     outTy = getParent()->uniqueType(*input.getType());
1227   }
1228   return addNode(new SoftMaxNode(name, outTy, input, selected));
1229 }
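// Minimal usage sketch (hypothetical names and values): passing a non-unit
// beta scales the logits before the softmax is applied, e.g.
//   auto *SM = F->createSoftMax("sm", logits, selected, /* outTy */ nullptr,
//                               /* beta */ 0.5f);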
1230 
createCrossEntropyLoss(llvm::StringRef name,NodeValue input,NodeValue labels)1231 CrossEntropyLossNode *Function::createCrossEntropyLoss(llvm::StringRef name,
1232                                                        NodeValue input,
1233                                                        NodeValue labels) {
1234   auto ty = getParent()->uniqueTypeWithNewShape(input.getType(), {1});
1235   return addNode(new CrossEntropyLossNode(name, ty, input, labels));
1236 }
1237 
createRegression(llvm::StringRef name,NodeValue input,NodeValue expected)1238 RegressionNode *Function::createRegression(llvm::StringRef name,
1239                                            NodeValue input,
1240                                            NodeValue expected) {
1241   return addNode(new RegressionNode(name, input, expected));
1242 }
1243 
1244 SigmoidCrossEntropyWithLogitsNode *
createSigmoidCrossEntropyWithLogits(llvm::StringRef name,NodeValue logits,NodeValue targets)1245 Function::createSigmoidCrossEntropyWithLogits(llvm::StringRef name,
1246                                               NodeValue logits,
1247                                               NodeValue targets) {
1248   assert(logits.dims().size() > 1);
1249   std::vector<dim_t> outDims(logits.dims().begin(), logits.dims().end() - 1);
1250   auto ty = getParent()->uniqueTypeWithNewShape(logits.getType(), outDims);
1251   return addNode(
1252       new SigmoidCrossEntropyWithLogitsNode(name, ty, logits, targets));
1253 }
1254 
createReshape(llvm::StringRef name,NodeValue input,llvm::ArrayRef<dim_t> shape,llvm::StringRef layout)1255 ReshapeNode *Function::createReshape(llvm::StringRef name, NodeValue input,
1256                                      llvm::ArrayRef<dim_t> shape,
1257                                      llvm::StringRef layout) {
1258   auto TR = getParent()->uniqueTypeWithNewShape(input.getType(), shape);
1259   DCHECK_EQ(TR->size(), input.getType()->size())
1260       << "Reshape to a different size";
1261   return addNode(new ReshapeNode(name, TR, input, shape.vec(), layout));
1262 }
1263 
createTranspose(llvm::StringRef name,NodeValue input,llvm::ArrayRef<unsigned_t> shuffle,const std::string & layout)1264 TransposeNode *Function::createTranspose(llvm::StringRef name, NodeValue input,
1265                                          llvm::ArrayRef<unsigned_t> shuffle,
1266                                          const std::string &layout) {
1267   ShapeVector shape;
1268   auto dims = input.dims();
1269   for (size_t i = 0; i < dims.size(); i++) {
1270     shape.push_back(dims[shuffle[i]]);
1271   }
1272 
1273   // If the layout is known, check that it matches the shuffle:
1274   auto compareShuffle = [&](const std::vector<unsigned_t> targetShuffle) {
1275     auto shuffleVec = shuffle.vec();
1276     return targetShuffle.size() == dims.size() &&
1277            std::equal(shuffleVec.begin(), shuffleVec.end(),
1278                       targetShuffle.begin());
1279   };
1280 
1281   auto currLayout = layout;
1282   if (currLayout == ANY_LAYOUT) {
1283     // If the layout has its default value, infer it from the shuffle:
1284     // TODO: remove the shuffle and replace it with layout.
1285     if (compareShuffle(NCHW2NHWC) || compareShuffle(HWCN2NHWC)) {
1286       currLayout = "NHWC";
1287     } else if (compareShuffle(NCTHW2NTHWC)) {
1288       currLayout = "NTHWC";
1289     } else if (compareShuffle(NHWC2NCHW)) {
1290       currLayout = "NCHW";
1291     } else if (compareShuffle(NTHWC2NCTHW)) {
1292       currLayout = "NCTHW";
1293     } else if (compareShuffle(NHWC2HWNC)) {
1294       currLayout = "HWNC";
1295     } else if (compareShuffle(CNHW2NHWC)) {
1296       currLayout = "NHWC";
1297     }
1298   }
1299 
1300   auto NT = getParent()->uniqueTypeWithNewShape(input.getType(), shape);
1301   return addNode(new TransposeNode(name, NT, input, shuffle.vec(), currLayout));
1302 }
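// Example (assuming the NCHW2NHWC shuffle is {0, 2, 3, 1}): transposing an
// input of shape {N, C, H, W} with that shuffle produces shape {N, H, W, C},
// and the resulting node is tagged with the "NHWC" layout.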
1303 
createFlip(llvm::StringRef name,NodeValue input,unsigned_t axis)1304 FlipNode *Function::createFlip(llvm::StringRef name, NodeValue input,
1305                                unsigned_t axis) {
1306   auto OT = getParent()->uniqueType(*input.getType());
1307   return addNode(new FlipNode(name, OT, input, axis));
1308 }
1309 
createBroadcast(llvm::StringRef name,NodeValue input,UnsignedArrayRef newShape,unsigned_t axis)1310 Node *Function::createBroadcast(llvm::StringRef name, NodeValue input,
1311                                 UnsignedArrayRef newShape, unsigned_t axis) {
1312   const auto &origDims = input.dims();
1313 
1314   assert(axis + origDims.size() <= newShape.size() &&
1315          "Axis must fit inside the newShape.");
1316 
1317   // Iterate over the new shape; if the original shape had a dimension here
1318   // (when considering the axis) then verify the dimension either matches the
1319   // new shape (no action taken) or == 1 (broadcast in that direction). Else
1320   // the original shape had no dimensions here (after considering axis), so
1321   // add the new dimension and broadcast in that direction.
1322   dim_t reshapeDims[max_tensor_dimensions];
1323   for (dim_t i = 0; i < newShape.size(); i++) {
1324     if (i >= axis && i < origDims.size() + axis) {
1325       const int origIdx = i - axis;
1326       if (origDims[origIdx] == newShape[i]) {
1327         // Keep original dimensions; they are compatible.
1328         reshapeDims[i] = origDims[origIdx];
1329       } else if (origDims[origIdx] == 1) {
1330         // Will broadcast this dimension to size from newShape.
1331         reshapeDims[i] = 1;
1332       } else {
1333         // Incompatible dimensions for broadcasting
1334         llvm_unreachable("Cannot broadcast with these dimensions.");
1335       }
1336     } else {
1337       // Will broadcast this dimension to size from newShape.
1338       reshapeDims[i] = 1;
1339     }
1340   }
1341 
1342   // Reshape the input node to same number of dimensions as new shape, but
1343   // with 1s in place of to-be-broadcasted dimensions.
1344   Node *currNode =
1345       createReshape(name.str() + ".reshape", input,
1346                     llvm::ArrayRef<dim_t>(reshapeDims, newShape.size()));
1347 
1348   // Create a Tile (which is really a Concat) in each direction that needs to
1349   // be broadcasted.
1350   for (size_t i = 0; i < newShape.size(); i++) {
1351     if (reshapeDims[i] == 1 && newShape[i] != 1) {
1352       currNode = createTile(name.str() + ".tile" + std::to_string(i), currNode,
1353                             newShape[i], i);
1354     }
1355   }
1356 
1357   return currNode;
1358 }
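// Worked example: broadcasting an input of shape {3, 1} to newShape {2, 3, 4}
// with axis = 1 first reshapes the input to {1, 3, 1}, then tiles dimension 0
// by 2 and dimension 2 by 4, yielding a {2, 3, 4} result.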
1359 
1360 /// \returns true if \p T1 and \p T2 have the exact same type except for dimension
1361 /// \p dim.
sameSameShapeExceptDim(TypeRef T1,TypeRef T2,unsigned dim)1362 static bool sameSameShapeExceptDim(TypeRef T1, TypeRef T2, unsigned dim) {
1363   if (T1->getElementType() != T2->getElementType()) {
1364     return false;
1365   }
1366 
1367   auto D1 = T1->dims();
1368   auto D2 = T2->dims();
1369 
1370   if (D1.size() != D2.size()) {
1371     return false;
1372   }
1373 
1374   for (unsigned i = 0, e = D1.size(); i < e; i++) {
1375     // Ignore the dimension \p dim.
1376     if (i == dim) {
1377       continue;
1378     }
1379 
1380     if (D1[i] != D2[i]) {
1381       return false;
1382     }
1383   }
1384 
1385   return true;
1386 }
1387 
createConcat(llvm::StringRef name,llvm::ArrayRef<NodeValue> inputs,unsigned_t dimension)1388 ConcatNode *Function::createConcat(llvm::StringRef name,
1389                                    llvm::ArrayRef<NodeValue> inputs,
1390                                    unsigned_t dimension) {
1391   for (int i = 1, e = inputs.size(); i < e; i++) {
1392     assert(sameSameShapeExceptDim(inputs[i].getType(), inputs[0].getType(),
1393                                   dimension) &&
1394            "Invalid type");
1395     (void)sameSameShapeExceptDim;
1396   }
1397   auto inDim = inputs[0].dims();
1398 
1399   ShapeVector shape(inDim.begin(), inDim.end());
1400 
1401   // We are stacking the tensors along a specific dimension. This means that
1402   // we increase the size of the tensor along this dimension.
1403   shape[dimension] = 0;
1404   for (auto I : inputs) {
1405     shape[dimension] += I.getType()->dims()[dimension];
1406   }
1407 
1408   auto NT = getParent()->uniqueTypeWithNewShape(inputs[0].getType(), shape);
1409   std::vector<NodeValue> ops;
1410   ops.reserve(inputs.size());
1411   for (auto I : inputs) {
1412     ops.emplace_back(I);
1413   }
1414   return addNode(new ConcatNode(name, NT, ops, dimension));
1415 }
1416 
createConcat(llvm::StringRef name,llvm::ArrayRef<NodeValue> inputs,unsigned_t dimension,TypeRef outTy)1417 ConcatNode *Function::createConcat(llvm::StringRef name,
1418                                    llvm::ArrayRef<NodeValue> inputs,
1419                                    unsigned_t dimension, TypeRef outTy) {
1420   std::vector<NodeValue> ops;
1421   ops.reserve(inputs.size());
1422   for (auto I : inputs) {
1423     ops.emplace_back(I);
1424   }
1425 
1426   TypeRef OT = getParent()->uniqueType(*outTy);
1427   return addNode(new ConcatNode(name, OT, ops, dimension));
1428 }
1429 
createTile(llvm::StringRef name,NodeValue input,unsigned_t tiles,unsigned_t axis,TypeRef outTy)1430 TileNode *Function::createTile(llvm::StringRef name, NodeValue input,
1431                                unsigned_t tiles, unsigned_t axis,
1432                                TypeRef outTy) {
1433   assert(tiles > 0 && "Tiles must be non-zero.");
1434   assert(axis >= 0 && axis < input.dims().size() &&
1435          "Axis must fall in range of source dims.");
1436 
1437   if (outTy == nullptr) {
1438     ShapeVector outShape(input.dims().begin(), input.dims().end());
1439     outShape[axis] *= tiles;
1440     outTy = getParent()->uniqueTypeWithNewShape(input.getType(), outShape);
1441   }
1442 
1443   return addNode(new TileNode(name, outTy, input, tiles, axis));
1444 }
1445 
createInsertTensor(llvm::StringRef name,NodeValue big,NodeValue small,llvm::ArrayRef<dim_t> start,unsigned_t count,unsigned_t axis)1446 InsertTensorNode *Function::createInsertTensor(llvm::StringRef name,
1447                                                NodeValue big, NodeValue small,
1448                                                llvm::ArrayRef<dim_t> start,
1449                                                unsigned_t count,
1450                                                unsigned_t axis) {
1451   return addNode(new InsertTensorNode(name, big, small, start, count, axis));
1452 }
1453 
createSlice(llvm::StringRef name,NodeValue input,llvm::ArrayRef<dim_t> start,TypeRef outTy)1454 SliceNode *Function::createSlice(llvm::StringRef name, NodeValue input,
1455                                  llvm::ArrayRef<dim_t> start, TypeRef outTy) {
1456   assert(input.dims().size() == start.size() &&
1457          "Start and input dims should match");
1458   assert(outTy->dims().size() == start.size() &&
1459          "Output and start dims should match");
1460 
1461   for (unsigned i = 0, e = input.dims().size(); i < e; i++) {
1462     assert(start[i] + outTy->dims()[i] <= input.dims()[i] &&
1463            "Input/Output/Start dims mismatch");
1464   }
1465 
1466   TypeRef OT = getParent()->uniqueType(*outTy);
1467   return addNode(new SliceNode(name, OT, input, start));
1468 }
1469 
createSlice(llvm::StringRef name,NodeValue input,llvm::ArrayRef<dim_t> begin,llvm::ArrayRef<dim_t> end)1470 SliceNode *Function::createSlice(llvm::StringRef name, NodeValue input,
1471                                  llvm::ArrayRef<dim_t> begin,
1472                                  llvm::ArrayRef<dim_t> end) {
1473   std::vector<dim_t> beginV, shape;
1474   auto dims = input.dims();
1475   assert(begin.size() == end.size() && "Begin and End dimensions should match");
1476   assert(begin.size() == dims.size() &&
1477          "Begin and Input dimensions should match");
1478   for (unsigned i = 0; i < dims.size(); i++) {
1479     dim_t beginI = begin[i];
1480     dim_t endI = end[i];
1481     dim_t dimI = dims[i];
1482     (void)dimI;
1483     assert(beginI >= 0 && "Illegal Begin indices");
1484     assert(endI > 0 && "Illegal End indices");
1485     assert(beginI < dimI && "Illegal Begin indices");
1486     assert(endI <= dimI && "Illegal End indices");
1487     assert(endI > beginI && "Illegal Begin and End indices");
1488     beginV.push_back(beginI);
1489     shape.push_back(endI - beginI);
1490   }
1491 
1492   auto NT = getParent()->uniqueTypeWithNewShape(input.getType(), shape);
1493   return addNode(new SliceNode(name, NT, input, beginV));
1494 }
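// Example: for an input of shape {5, 10}, begin = {1, 2} and end = {4, 8}
// produce a slice of shape {3, 6} whose element (0, 0) is input element (1, 2).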
1495 
createChannelShuffle(llvm::StringRef name,NodeValue input,size_t group,size_t kernel)1496 Node *Function::createChannelShuffle(llvm::StringRef name, NodeValue input,
1497                                      size_t group, size_t kernel) {
1498   return addNode(
1499       new ChannelShuffleNode(name, input.getType(), input, group, kernel));
1500 }
1501 
createSqueeze(llvm::StringRef name,NodeValue input,llvm::ArrayRef<dim_t> axes)1502 ReshapeNode *Function::createSqueeze(llvm::StringRef name, NodeValue input,
1503                                      llvm::ArrayRef<dim_t> axes) {
1504   assert(!axes.empty() && "Parameter `axes` must be provided.");
1505 
1506   ShapeVector shapeAxes(axes.begin(), axes.end());
1507 
1508   // Sort and unique the values in axes to
1509   // 1. make sure each dim is only removed once;
1510   // 2. check if the size and value of dimensions to squeeze are valid.
1511   std::sort(shapeAxes.begin(), shapeAxes.end());
1512   shapeAxes.erase(std::unique(shapeAxes.begin(), shapeAxes.end()),
1513                   shapeAxes.end());
1514   auto inDims = input.dims();
1515   assert(shapeAxes.back() < inDims.size() &&
1516          "Axes to squeeze must be valid "
1517          "dimensions of the input.");
1518 
1519   ShapeVector newDims;
1520   size_t j = 0;
1521   for (size_t i = 0, e = inDims.size(); i < e; i++) {
1522     if (j < shapeAxes.size() && shapeAxes[j] == i) {
1523       assert(inDims[i] == 1 && "The dimension to squeeze must be 1.");
1524       j++;
1525     } else {
1526       newDims.push_back(inDims[i]);
1527     }
1528   }
1529   return createReshape(name.str() + ".reshape", input, newDims);
1530 }
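// Example: squeezing an input of shape {1, 3, 1, 4} with axes {0, 2} removes
// the two size-1 dimensions and reshapes the tensor to {3, 4}.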
1531 
createExpandDims(llvm::StringRef name,NodeValue input,llvm::ArrayRef<dim_t> axes)1532 ReshapeNode *Function::createExpandDims(llvm::StringRef name, NodeValue input,
1533                                         llvm::ArrayRef<dim_t> axes) {
1534   assert(!axes.empty() && "Parameter `axes` must be provided.");
1535 
1536   // Dimensions provided in axes are for the output tensor, so we sort them
1537   // and unique them to make sure they are processed correctly and in the
1538   // right order.
1539   ShapeVector shapeAxes(axes.begin(), axes.end());
1540   std::sort(shapeAxes.begin(), shapeAxes.end());
1541   shapeAxes.erase(std::unique(shapeAxes.begin(), shapeAxes.end()),
1542                   shapeAxes.end());
1543 
1544   const auto inDims = input.dims();
1545 
1546   // The total number of dimensions in the new shape is equal to the original
1547   // shape size plus the uniqued new shape axes, which represent where to
1548   // insert dimensions of 1 into the output tensor's shape.
1549   const size_t totalNumNewDims = shapeAxes.size() + inDims.size();
1550   assert(totalNumNewDims <= max_tensor_dimensions &&
1551          "New expanded shape has too many dimensions.");
1552   assert(shapeAxes.back() < totalNumNewDims &&
1553          "Specified axis expands outside size of output tensor shape.");
1554   ShapeVector newDims;
1555   for (size_t i = 0, j = 0, k = 0; k < totalNumNewDims; k++) {
1556     if (j < shapeAxes.size() && shapeAxes[j] == k) {
1557       newDims.push_back(1);
1558       j++;
1559     } else {
1560       assert(i < inDims.size() && "Somehow overflowing inDims.");
1561       newDims.push_back(inDims[i]);
1562       i++;
1563     }
1564   }
1565 
1566   // Create a reshape of the original data with the newly determined
1567   // dimensions.
1568   return createReshape(name.str() + ".expanddims", input, newDims);
1569 }
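// Example: expanding an input of shape {3, 4} with axes {0, 2} inserts size-1
// dimensions at output positions 0 and 2, producing shape {1, 3, 1, 4}.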
1570 
createFlatten(llvm::StringRef name,NodeValue input,unsigned_t axis)1571 ReshapeNode *Function::createFlatten(llvm::StringRef name, NodeValue input,
1572                                      unsigned_t axis) {
1573   auto xDim = flattenCdr(input.getType()->dims(), axis);
1574   return createReshape(name, input, {xDim.first, xDim.second});
1575 }
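// Example (assuming flattenCdr collapses the dimensions before and from \p
// axis into two): an input of shape {2, 3, 4, 5} flattened with axis = 2 is
// reshaped to {6, 20}.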
1576 
createSplit(llvm::StringRef name,NodeValue input,unsigned_t outputNum,unsigned_t axis,llvm::ArrayRef<dim_t> split,std::vector<SliceNode * > & outputs)1577 void Function::createSplit(llvm::StringRef name, NodeValue input,
1578                            unsigned_t outputNum, unsigned_t axis,
1579                            llvm::ArrayRef<dim_t> split,
1580                            std::vector<SliceNode *> &outputs) {
1581   auto inDims = input.dims();
1582   if (split.empty()) {
1583     assert(inDims[axis] % outputNum == 0 &&
1584            "Dimension to split must be divisible by outputs number.");
1585   } else {
1586     assert(outputNum == split.size() &&
1587            "Number of splits must be divisible by outputs number.");
1588   }
1589 
1590   ShapeVector start(inDims.size(), 0);
1591   ShapeVector end(inDims.begin(), inDims.end());
1592   end[axis] = 0;
1593 
1594   outputs.resize(outputNum);
1595   for (size_t i = 0; i < outputNum; i++) {
1596     size_t curLength = split.empty() ? inDims[axis] / outputNum : split[i];
1597     end[axis] += curLength;
1598     outputs[i] =
1599         createSlice(name.str() + ".out" + std::to_string(i), input, start, end);
1600     start[axis] = end[axis];
1601   }
1602 
1603   assert(end[axis] == inDims[axis] &&
1604          "Total size of results must be equal to input size.");
1605 }
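// Example: splitting an input of shape {6, 4} along axis 0 into 3 outputs with
// an empty \p split yields three slices of shape {2, 4}; with split = {1, 2, 3}
// it yields slices of shapes {1, 4}, {2, 4}, and {3, 4}.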
1606 
createBatchNormalization(llvm::StringRef name,NodeValue input,NodeValue beta,NodeValue scale,NodeValue mean,NodeValue var,unsigned_t channelIdx,float epsilon,float momentum)1607 BatchNormalizationNode *Function::createBatchNormalization(
1608     llvm::StringRef name, NodeValue input, NodeValue beta, NodeValue scale,
1609     NodeValue mean, NodeValue var, unsigned_t channelIdx, float epsilon,
1610     float momentum) {
1611   return addNode(new BatchNormalizationNode(name, input, scale, beta, mean, var,
1612                                             channelIdx, epsilon, momentum));
1613 }
1614 
createLayerNormalization(llvm::StringRef name,NodeValue input,NodeValue scale,NodeValue bias,float epsilon)1615 LayerNormalizationNode *Function::createLayerNormalization(llvm::StringRef name,
1616                                                            NodeValue input,
1617                                                            NodeValue scale,
1618                                                            NodeValue bias,
1619                                                            float epsilon) {
1620   return addNode(new LayerNormalizationNode(name, input, scale, bias, epsilon));
1621 }
1622 
createBucketizeNode(llvm::StringRef name,NodeValue input,llvm::ArrayRef<float> boundaries)1623 BucketizeNode *Function::createBucketizeNode(llvm::StringRef name,
1624                                              NodeValue input,
1625                                              llvm::ArrayRef<float> boundaries) {
1626   auto OT = getParent()->uniqueType(ElemKind::Int32ITy, input.dims());
1627   return addNode(new BucketizeNode(name, OT, input, boundaries));
1628 }
1629 
createLocalResponseNormalization(llvm::StringRef name,NodeValue input,unsigned_t halfWindowSize,float alpha,float beta,float k)1630 LocalResponseNormalizationNode *Function::createLocalResponseNormalization(
1631     llvm::StringRef name, NodeValue input, unsigned_t halfWindowSize,
1632     float alpha, float beta, float k) {
1633   // The output tensor is of the same shape as the input tensor.
1634   return addNode(new LocalResponseNormalizationNode(name, input, halfWindowSize,
1635                                                     alpha, beta, k));
1636 }
1637 
createModulo(llvm::StringRef name,NodeValue input,int64_t divisor,bool signFollowDivisor)1638 ModuloNode *Function::createModulo(llvm::StringRef name, NodeValue input,
1639                                    int64_t divisor, bool signFollowDivisor) {
1640   // The output tensor is of the same shape as the input tensor.
1641   auto OT = getParent()->uniqueType(*input.getType());
1642   return addNode(new ModuloNode(name, OT, input, divisor, signFollowDivisor));
1643 }
1644 
createNot(llvm::StringRef name,NodeValue input)1645 NotNode *Function::createNot(llvm::StringRef name, NodeValue input) {
1646   TypeRef OT = getParent()->uniqueType(ElemKind::BoolTy, input.dims());
1647   return addNode(new NotNode(name, OT, input));
1648 }
1649 
1650 #define UNARY_ARITHMETIC_FUN_DEF(NODE_NAME_)                                   \
1651   NODE_NAME_##Node *Function::create##NODE_NAME_(llvm::StringRef name,         \
1652                                                  NodeValue input) {            \
1653     return create##NODE_NAME_(name, input.getType(), input);                   \
1654   }                                                                            \
1655   NODE_NAME_##Node *Function::create##NODE_NAME_(llvm::StringRef name,         \
1656                                                  TypeRef T, NodeValue input) { \
1657     TypeRef OT = getParent()->uniqueType(*T);                                  \
1658     return addNode(new NODE_NAME_##Node(name, OT, input));                     \
1659   }
1660 UNARY_ARITHMETIC_FUN_DEF(Abs)
1661 UNARY_ARITHMETIC_FUN_DEF(Neg)
1662 UNARY_ARITHMETIC_FUN_DEF(Floor)
1663 UNARY_ARITHMETIC_FUN_DEF(Ceil)
1664 UNARY_ARITHMETIC_FUN_DEF(Round)
1665 UNARY_ARITHMETIC_FUN_DEF(Sqrt)
1666 UNARY_ARITHMETIC_FUN_DEF(Rsqrt)
1667 UNARY_ARITHMETIC_FUN_DEF(Reciprocal)
1668 UNARY_ARITHMETIC_FUN_DEF(Sin)
1669 UNARY_ARITHMETIC_FUN_DEF(Cos)
1670 #undef UNARY_ARITHMETIC_FUN_DEF
1671 
1672 #define ARITHMETIC_FUN_DEF(NODE_NAME_)                                         \
1673   NODE_NAME_##Node *Function::create##NODE_NAME_(                              \
1674       llvm::StringRef name, NodeValue LHS, NodeValue RHS) {                    \
1675     return create##NODE_NAME_(name, LHS.getType(), LHS, RHS);                  \
1676   }                                                                            \
1677   NODE_NAME_##Node *Function::create##NODE_NAME_(                              \
1678       llvm::StringRef name, TypeRef T, NodeValue LHS, NodeValue RHS) {         \
1679     DCHECK(LHS.dims() == RHS.dims())                                           \
1680         << "Invalid operand shapes " << LHS.dims() << " vs " << RHS.dims();    \
1681     TypeRef OT = getParent()->uniqueType(*T);                                  \
1682     return addNode(new NODE_NAME_##Node(name, OT, LHS, RHS));                  \
1683   }
1684 ARITHMETIC_FUN_DEF(Add);
1685 ARITHMETIC_FUN_DEF(Mul);
1686 ARITHMETIC_FUN_DEF(Sub);
1687 ARITHMETIC_FUN_DEF(Div);
1688 ARITHMETIC_FUN_DEF(Max);
1689 ARITHMETIC_FUN_DEF(Min);
1690 ARITHMETIC_FUN_DEF(Pow);
1691 #undef ARITHMETIC_FUN_DEF
1692 
createAnd(llvm::StringRef name,NodeValue LHS,NodeValue RHS)1693 AndNode *Function::createAnd(llvm::StringRef name, NodeValue LHS,
1694                              NodeValue RHS) {
1695   assert(LHS.dims() == RHS.dims() && "Invalid operand shapes");
1696   TypeRef OT = getParent()->uniqueType(ElemKind::BoolTy, LHS.dims());
1697   return addNode(new AndNode(name, OT, LHS, RHS));
1698 }
1699 
createOr(llvm::StringRef name,NodeValue LHS,NodeValue RHS)1700 OrNode *Function::createOr(llvm::StringRef name, NodeValue LHS, NodeValue RHS) {
1701   assert(LHS.dims() == RHS.dims() && "Invalid operand shapes");
1702   TypeRef OT = getParent()->uniqueType(ElemKind::BoolTy, LHS.dims());
1703   return addNode(new OrNode(name, OT, LHS, RHS));
1704 }
1705 
createXor(llvm::StringRef name,NodeValue LHS,NodeValue RHS)1706 XorNode *Function::createXor(llvm::StringRef name, NodeValue LHS,
1707                              NodeValue RHS) {
1708   assert(LHS.dims() == RHS.dims() && "Invalid operand shapes");
1709   TypeRef OT = getParent()->uniqueType(ElemKind::BoolTy, LHS.dims());
1710   return addNode(new XorNode(name, OT, LHS, RHS));
1711 }
1712 
createCmpLTE(llvm::StringRef name,NodeValue LHS,NodeValue RHS)1713 CmpLTENode *Function::createCmpLTE(llvm::StringRef name, NodeValue LHS,
1714                                    NodeValue RHS) {
1715   assert(LHS.dims() == RHS.dims() && "Invalid operand shapes");
1716   TypeRef OT = getParent()->uniqueType(ElemKind::BoolTy, LHS.dims());
1717   return addNode(new CmpLTENode(name, OT, LHS, RHS));
1718 }
1719 
createCmpLT(llvm::StringRef name,NodeValue LHS,NodeValue RHS)1720 CmpLTNode *Function::createCmpLT(llvm::StringRef name, NodeValue LHS,
1721                                  NodeValue RHS) {
1722   assert(LHS.dims() == RHS.dims() && "Invalid operand shapes");
1723   TypeRef OT = getParent()->uniqueType(ElemKind::BoolTy, LHS.dims());
1724   return addNode(new CmpLTNode(name, OT, LHS, RHS));
1725 }
1726 
createCmpGTE(llvm::StringRef name,NodeValue LHS,NodeValue RHS)1727 CmpLTENode *Function::createCmpGTE(llvm::StringRef name, NodeValue LHS,
1728                                    NodeValue RHS) {
1729   assert(LHS.dims() == RHS.dims() && "Invalid operand shapes");
1730   TypeRef OT = getParent()->uniqueType(ElemKind::BoolTy, LHS.dims());
1731   return addNode(new CmpLTENode(name, OT, RHS, LHS));
1732 }
1733 
createCmpGT(llvm::StringRef name,NodeValue LHS,NodeValue RHS)1734 CmpLTNode *Function::createCmpGT(llvm::StringRef name, NodeValue LHS,
1735                                  NodeValue RHS) {
1736   assert(LHS.dims() == RHS.dims() && "Invalid operand shapes");
1737   TypeRef OT = getParent()->uniqueType(ElemKind::BoolTy, LHS.dims());
1738   return addNode(new CmpLTNode(name, OT, RHS, LHS));
1739 }
1740 
createCmpEQ(llvm::StringRef name,NodeValue LHS,NodeValue RHS)1741 CmpEQNode *Function::createCmpEQ(llvm::StringRef name, NodeValue LHS,
1742                                  NodeValue RHS) {
1743   assert(LHS.dims() == RHS.dims() && "Invalid operand shapes");
1744   TypeRef OT = getParent()->uniqueType(ElemKind::BoolTy, LHS.dims());
1745   return addNode(new CmpEQNode(name, OT, LHS, RHS));
1746 }
1747 
createCmpNEQ(llvm::StringRef name,NodeValue LHS,NodeValue RHS)1748 CmpNEQNode *Function::createCmpNEQ(llvm::StringRef name, NodeValue LHS,
1749                                    NodeValue RHS) {
1750   assert(LHS.dims() == RHS.dims() && "Invalid operand shapes");
1751   TypeRef OT = getParent()->uniqueType(ElemKind::BoolTy, LHS.dims());
1752   return addNode(new CmpNEQNode(name, OT, LHS, RHS));
1753 }
1754 
createSquare(llvm::StringRef name,NodeValue input)1755 MulNode *Function::createSquare(llvm::StringRef name, NodeValue input) {
1756   return createMul(name, input, input);
1757 }
1758 
createSquare(llvm::StringRef name,TypeRef outTy,NodeValue input)1759 MulNode *Function::createSquare(llvm::StringRef name, TypeRef outTy,
1760                                 NodeValue input) {
1761   return createMul(name, outTy, input, input);
1762 }
1763 
createLeakyRELU(llvm::StringRef name,NodeValue input,float alpha)1764 PReluNode *Function::createLeakyRELU(llvm::StringRef name, NodeValue input,
1765                                      float alpha) {
1766   return createLeakyRELU(name, input.getType(), input, alpha);
1767 }
1768 
createLeakyRELU(llvm::StringRef name,TypeRef outTy,NodeValue input,float alpha)1769 PReluNode *Function::createLeakyRELU(llvm::StringRef name, TypeRef outTy,
1770                                      NodeValue input, float alpha) {
1771   auto splatType = getParent()->uniqueType(*(input.getType()));
1772   SplatNode *splat = createSplat(name.str() + ".alpha", splatType, alpha);
1773   auto OT = getParent()->uniqueType(*outTy);
1774   return createPRELU(name, input, splat, OT);
1775 }
1776 
createIsNaN(llvm::StringRef name,NodeValue input)1777 IsNaNNode *Function::createIsNaN(llvm::StringRef name, NodeValue input) {
1778   TypeRef OT = getParent()->uniqueType(ElemKind::BoolTy, input.dims());
1779   return addNode(new IsNaNNode(name, OT, input));
1780 }
1781 
createReplaceNaN(llvm::StringRef name,NodeValue input,float value)1782 ReplaceNaNNode *Function::createReplaceNaN(llvm::StringRef name,
1783                                            NodeValue input, float value) {
1784   return addNode(new ReplaceNaNNode(name, input.getType(), input, value));
1785 }
1786 
createPow(llvm::StringRef name,NodeValue base,float exp)1787 PowNode *Function::createPow(llvm::StringRef name, NodeValue base, float exp) {
1788   auto *SP = createSplat(name, base.getType(), exp);
1789   return createPow(name, base, SP);
1790 }
1791 
createLog(llvm::StringRef name,NodeValue input,TypeRef outTy)1792 LogNode *Function::createLog(llvm::StringRef name, NodeValue input,
1793                              TypeRef outTy) {
1794   return addNode(new LogNode(name, outTy ? outTy : input.getType(), input));
1795 }
1796 
createExp(llvm::StringRef name,NodeValue input)1797 ExpNode *Function::createExp(llvm::StringRef name, NodeValue input) {
1798   return addNode(new ExpNode(name, input.getType(), input));
1799 }
1800 
createExp(llvm::StringRef name,TypeRef outTy,NodeValue input)1801 ExpNode *Function::createExp(llvm::StringRef name, TypeRef outTy,
1802                              NodeValue input) {
1803   return addNode(new ExpNode(name, outTy, input));
1804 }
1805 
createLogit(llvm::StringRef name,NodeValue input,float eps)1806 LogitNode *Function::createLogit(llvm::StringRef name, NodeValue input,
1807                                  float eps) {
1808   return addNode(new LogitNode(name, input.getType(), input, eps));
1809 }
1810 
createSelect(llvm::StringRef name,TypeRef outTy,NodeValue Cond,NodeValue LHS,NodeValue RHS)1811 SelectNode *Function::createSelect(llvm::StringRef name, TypeRef outTy,
1812                                    NodeValue Cond, NodeValue LHS,
1813                                    NodeValue RHS) {
1814   assert(LHS.dims() == RHS.dims() && "Invalid operand shapes");
1815   assert(LHS.dims() == Cond.dims() && "Invalid operand shapes");
1816   assert(LHS.dims() == outTy->dims() && "Invalid result shape");
1817   auto OT = getParent()->uniqueType(*outTy);
1818   return addNode(new SelectNode(name, OT, Cond, LHS, RHS));
1819 }
1820 
createSelect(llvm::StringRef name,NodeValue Cond,NodeValue LHS,NodeValue RHS)1821 SelectNode *Function::createSelect(llvm::StringRef name, NodeValue Cond,
1822                                    NodeValue LHS, NodeValue RHS) {
1823   auto inDims = LHS.dims();
1824   assert(inDims.size() > 0);
1825   ShapeVector outDims(inDims.begin(), inDims.end());
1826   auto OT = getParent()->uniqueTypeWithNewShape(LHS.getType(), outDims);
1827   return createSelect(name, OT, Cond, LHS, RHS);
1828 }
1829 
createSplat(llvm::StringRef name,TypeRef ty,float value)1830 SplatNode *Function::createSplat(llvm::StringRef name, TypeRef ty,
1831                                  float value) {
1832   return addNode(new SplatNode(name, getParent()->uniqueType(*ty), value));
1833 }
1834 
createTouch(llvm::StringRef name,TypeRef ty)1835 TouchNode *Function::createTouch(llvm::StringRef name, TypeRef ty) {
1836   return addNode(new TouchNode(name, getParent()->uniqueType(*ty)));
1837 }
1838 
createMatMul(llvm::StringRef name,TypeRef outTy,NodeValue lhs,NodeValue rhs)1839 MatMulNode *Function::createMatMul(llvm::StringRef name, TypeRef outTy,
1840                                    NodeValue lhs, NodeValue rhs) {
1841   return addNode(
1842       new MatMulNode(name, getParent()->uniqueType(*outTy), lhs, rhs));
1843 }
1844 
createMatMul(llvm::StringRef name,NodeValue lhs,NodeValue rhs)1845 MatMulNode *Function::createMatMul(llvm::StringRef name, NodeValue lhs,
1846                                    NodeValue rhs) {
1847   auto LT = lhs.getType();
1848   auto RT = rhs.getType();
1849   auto LDims = LT->dims();
1850   auto RDims = RT->dims();
1851   assert(lhs.getType()->getElementType() == rhs.getType()->getElementType());
1852 
1853   auto ty =
1854       getParent()->uniqueTypeWithNewShape(lhs.getType(), {LDims[0], RDims[1]});
1855   return createMatMul(name, ty, lhs, rhs);
1856 }
1857 
createBatchMatMul(llvm::StringRef name,NodeValue LHS,NodeValue RHS)1858 BatchMatMulNode *Function::createBatchMatMul(llvm::StringRef name,
1859                                              NodeValue LHS, NodeValue RHS) {
1860   const size_t numDimsRHS = RHS.dims().size();
1861   assert(LHS.dims().size() == 3 && "LHS must be 3 dimensional.");
1862   assert((numDimsRHS == 2 || numDimsRHS == 3) &&
1863          "RHS must be 2 or 3 dimensional.");
1864 
1865   // If necessary, expand the RHS input to be 3D by adding initial leading
1866   // dim.
1867   if (numDimsRHS == 2) {
1868     RHS = createExpandDims(name.str() + ".reshapeRHS", RHS, {0});
1869   }
1870   // If necessary, Tile the RHS input so it matches the numBatches of LHS.
1871   if (RHS.dims()[0] == 1 && LHS.dims()[0] != 1) {
1872     RHS = createTile(name.str() + ".tileRHS", RHS, LHS.dims()[0], /*axis */ 0);
1873   }
1874 
1875   // LHS = {numBatches, N, M}
1876   // RHS = {numBatches, M, P}
1877   // Result = {numBatches, N, P}
1878   const dim_t numBatches = LHS.dims()[0];
1879   const dim_t N = LHS.dims()[1];
1880   const dim_t M = LHS.dims()[2];
1881   (void)M;
1882   const dim_t P = RHS.dims()[2];
1883   assert((RHS.dims()[0] == numBatches) && "Batch sizes are invalid.");
1884   assert((RHS.dims()[1] == M) && "Batch matmul dimensions are invalid.");
1885 
1886   auto OT =
1887       getParent()->uniqueTypeWithNewShape(LHS.getType(), {numBatches, N, P});
1888   return addNode(new BatchMatMulNode(name, OT, LHS, RHS));
1889 }
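// Shape example: for LHS of shape {4, 2, 3} and RHS of shape {3, 5}, RHS is
// first expanded to {1, 3, 5} and tiled to {4, 3, 5}, and the result has
// shape {4, 2, 5}.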
1890 
1891 BatchedReduceAddNode *
createBatchedReduceAdd(llvm::StringRef name,TypeRef outTy,NodeValue batch,llvm::ArrayRef<unsigned_t> axes)1892 Function::createBatchedReduceAdd(llvm::StringRef name, TypeRef outTy,
1893                                  NodeValue batch,
1894                                  llvm::ArrayRef<unsigned_t> axes) {
1895   assert(axes.size() == 1 && "Only supporting single reduction for now.");
1896   auto axis = axes[0];
1897 
1898   // Calculate the expected total number of elements in the output tensor
1899   // based on the number of elements in the batch divided by the axis
1900   // dimension.
1901   const size_t outNumElements = batch.getType()->size() / batch.dims()[axis];
1902   (void)outNumElements;
1903   assert(outTy->size() == outNumElements &&
1904          "Incorrect number of elements in the output type.");
1905   auto OT = getParent()->uniqueType(*outTy);
1906   return addNode(new BatchedReduceAddNode(name, OT, batch, axis));
1907 }
1908 
1909 BatchedReduceAddNode *
createBatchedReduceAdd(llvm::StringRef name,NodeValue batch,llvm::ArrayRef<unsigned_t> axes)1910 Function::createBatchedReduceAdd(llvm::StringRef name, NodeValue batch,
1911                                  llvm::ArrayRef<unsigned_t> axes) {
1912   auto outDims = getNewShapeWithoutAxes(batch.dims(), axes);
1913   auto OT = getParent()->uniqueTypeWithNewShape(batch.getType(), outDims);
1914   return createBatchedReduceAdd(name, OT, batch, axes);
1915 }
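// Example (assuming getNewShapeWithoutAxes drops the reduced axes): reducing a
// batch of shape {3, 4, 5} over axes {1} produces an output of shape {3, 5}.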
1916 
1917 BatchedReduceMeanNode *
createBatchedReduceMean(llvm::StringRef name,TypeRef outTy,NodeValue batch,llvm::ArrayRef<unsigned_t> axes)1918 Function::createBatchedReduceMean(llvm::StringRef name, TypeRef outTy,
1919                                   NodeValue batch,
1920                                   llvm::ArrayRef<unsigned_t> axes) {
1921   auto OT = getParent()->uniqueType(*outTy);
1922   return addNode(new BatchedReduceMeanNode(name, OT, batch, axes));
1923 }
1924 
1925 BatchedReduceMeanNode *
createBatchedReduceMean(llvm::StringRef name,NodeValue batch,llvm::ArrayRef<unsigned_t> axes)1926 Function::createBatchedReduceMean(llvm::StringRef name, NodeValue batch,
1927                                   llvm::ArrayRef<unsigned_t> axes) {
1928   // Create new shape with specified dimensions either reduced or removed.
1929   auto outDims = getNewShapeWithoutAxes(batch.dims(), axes);
1930   auto OT = getParent()->uniqueTypeWithNewShape(batch.getType(), outDims);
1931   return createBatchedReduceMean(name, OT, batch, axes);
1932 }
1933 
1934 BatchedReduceMinNode *
createBatchedReduceMin(llvm::StringRef name,NodeValue batch,llvm::ArrayRef<unsigned_t> axes)1935 Function::createBatchedReduceMin(llvm::StringRef name, NodeValue batch,
1936                                  llvm::ArrayRef<unsigned_t> axes) {
1937   // Create new shape with specified dimensions either reduced or removed.
1938   auto outDims = getNewShapeWithoutAxes(batch.dims(), axes);
1939   auto OT = getParent()->uniqueType(batch.getType()->getElementType(), outDims);
1940   return addNode(new BatchedReduceMinNode(name, OT, batch, axes));
1941 }
1942 
createBatchedAdd(llvm::StringRef name,NodeValue batch,NodeValue slice)1943 BatchedAddNode *Function::createBatchedAdd(llvm::StringRef name,
1944                                            NodeValue batch, NodeValue slice) {
1945   return addNode(new BatchedAddNode(name, batch.getType(), batch, slice));
1946 }
1947 
createBatchedAdd(llvm::StringRef name,TypeRef outTy,NodeValue batch,NodeValue slice)1948 BatchedAddNode *Function::createBatchedAdd(llvm::StringRef name, TypeRef outTy,
1949                                            NodeValue batch, NodeValue slice) {
1950   return addNode(
1951       new BatchedAddNode(name, getParent()->uniqueType(*outTy), batch, slice));
1952 }
1953 
createCumSum(llvm::StringRef name,NodeValue input,bool exclusive,bool reverse)1954 CumSumNode *Function::createCumSum(llvm::StringRef name, NodeValue input,
1955                                    bool exclusive, bool reverse) {
1956   return addNode(
1957       new CumSumNode(name, input.getType(), input, exclusive, reverse));
1958 }
1959 
createLengthsSum(llvm::StringRef name,NodeValue data,NodeValue lengths)1960 LengthsSumNode *Function::createLengthsSum(llvm::StringRef name, NodeValue data,
1961                                            NodeValue lengths) {
1962   ShapeVector outDims(data.dims().begin(), data.dims().end());
1963   outDims[0] = lengths.dims()[0];
1964   auto outTy = getParent()->uniqueTypeWithNewShape(data.getType(), outDims);
1965   return addNode(new LengthsSumNode(name, outTy, data, lengths));
1966 }
1967 
1968 SparseLengthsSumNode *
createSparseLengthsSum(llvm::StringRef name,NodeValue data,NodeValue indices,NodeValue lengths,LengthsMode lengthsMode,float avgLength)1969 Function::createSparseLengthsSum(llvm::StringRef name, NodeValue data,
1970                                  NodeValue indices, NodeValue lengths,
1971                                  LengthsMode lengthsMode, float avgLength) {
1972   auto inDims = data.dims();
1973   ShapeVector outDims(inDims.begin(), inDims.end());
1974   outDims[0] = lengths.dims()[0];
1975   auto outTy = getParent()->uniqueTypeWithNewShape(data.getType(), outDims);
1976   return addNode(new SparseLengthsSumNode(name, outTy, data, indices, lengths,
1977                                           lengthsMode, avgLength));
1978 }
1979 
createSparseLengthsWeightedSum(llvm::StringRef name,NodeValue data,NodeValue weights,NodeValue indices,NodeValue lengths,LengthsMode lengthsMode,float avgLength)1980 SparseLengthsWeightedSumNode *Function::createSparseLengthsWeightedSum(
1981     llvm::StringRef name, NodeValue data, NodeValue weights, NodeValue indices,
1982     NodeValue lengths, LengthsMode lengthsMode, float avgLength) {
1983   auto inDims = data.dims();
1984   ShapeVector outDims(inDims.begin(), inDims.end());
1985   outDims[0] = lengths.dims()[0];
1986   auto outTy = getParent()->uniqueTypeWithNewShape(data.getType(), outDims);
1987   return addNode(new SparseLengthsWeightedSumNode(
1988       name, outTy, data, weights, indices, lengths, lengthsMode, avgLength));
1989 }
1990 
createSparseLengthsWeightedSum(llvm::StringRef name,TypeRef outTy,NodeValue data,NodeValue weights,NodeValue indices,NodeValue lengths,LengthsMode lengthsMode,float avgLength)1991 SparseLengthsWeightedSumNode *Function::createSparseLengthsWeightedSum(
1992     llvm::StringRef name, TypeRef outTy, NodeValue data, NodeValue weights,
1993     NodeValue indices, NodeValue lengths, LengthsMode lengthsMode,
1994     float avgLength) {
1995   return addNode(new SparseLengthsWeightedSumNode(
1996       name, outTy, data, weights, indices, lengths, lengthsMode, avgLength));
1997 }
1998 
1999 RowwiseQuantizedSparseLengthsWeightedSumNode *
createRowwiseQuantizedSparseLengthsWeightedSum(llvm::StringRef name,Storage * data,Constant * scales,Constant * offsets,NodeValue weights,NodeValue indices,NodeValue lengths,ElemKind precision,bool useFP16Accumulation,LengthsMode lengthsMode,float avgLength)2000 Function::createRowwiseQuantizedSparseLengthsWeightedSum(
2001     llvm::StringRef name, Storage *data, Constant *scales, Constant *offsets,
2002     NodeValue weights, NodeValue indices, NodeValue lengths, ElemKind precision,
2003     bool useFP16Accumulation, LengthsMode lengthsMode, float avgLength) {
2004   auto inDims = data->dims();
2005   ShapeVector outDims(inDims.begin(), inDims.end());
2006   outDims[0] = lengths.dims()[0];
2007   auto outTy = getParent()->uniqueType(precision, outDims);
2008   return addNode(new RowwiseQuantizedSparseLengthsWeightedSumNode(
2009       name, outTy, data, scales, offsets, weights, indices, lengths,
2010       useFP16Accumulation, lengthsMode, avgLength));
2011 }
2012 
2013 RowwiseQuantizedSparseLengthsWeightedSumNode *
createRowwiseQuantizedSparseLengthsSum(llvm::StringRef name,Storage * data,Constant * scales,Constant * offsets,NodeValue indices,NodeValue lengths,ElemKind precision,bool useFP16Accumulation,LengthsMode lengthsMode,float avgLength)2014 Function::createRowwiseQuantizedSparseLengthsSum(
2015     llvm::StringRef name, Storage *data, Constant *scales, Constant *offsets,
2016     NodeValue indices, NodeValue lengths, ElemKind precision,
2017     bool useFP16Accumulation, LengthsMode lengthsMode, float avgLength) {
2018   auto ty = getParent()->uniqueType(precision, {indices.dims()[0]});
2019   auto ones = createSplat(name.str() + ".ones", ty, 1.0);
2020   return createRowwiseQuantizedSparseLengthsWeightedSum(
2021       name, data, scales, offsets, ones, indices, lengths, precision,
2022       useFP16Accumulation, lengthsMode, avgLength);
2023 }
2024 
2025 /// Helper to create a RowwiseQuantizedSparseLengthsWeightedSumNode in the
2026 /// Function \p F with \p name, using \p data, \p weights, \p indices, and \p
2027 /// lengths as inputs. The provided float Tensor \p data is rowwise
2028 /// quantized, creating Constants for the rowwise quantized data as well as
2029 /// Scales and Offsets, in the Module containing \p F.
2030 static RowwiseQuantizedSparseLengthsWeightedSumNode *
quantizeDataAndCreateRowwiseQuantizedSparseLengthsWeightedSum(Function * F,llvm::StringRef name,Tensor & data,NodeValue weights,NodeValue indices,NodeValue lengths,quantization::Schema schema,ElemKind precision,bool useFP16Accumulation,LengthsMode lengthsMode,float avgLength)2031 quantizeDataAndCreateRowwiseQuantizedSparseLengthsWeightedSum(
2032     Function *F, llvm::StringRef name, Tensor &data, NodeValue weights,
2033     NodeValue indices, NodeValue lengths, quantization::Schema schema,
2034     ElemKind precision, bool useFP16Accumulation, LengthsMode lengthsMode,
2035     float avgLength) {
2036   auto inDims = data.dims();
2037 
2038   // Note: In rwqData, we are using a quantized type, however the scale/offset
2039   // are set to dummy values 0.0/0. This is because the actually used
2040   // scale/offset come from dataScales and dataOffsets.
2041   Constant *rwqData = F->getParent()->createConstant(ElemKind::UInt8QTy, inDims,
2042                                                      0.0, 0, "data");
2043   Constant *dataScales =
2044       F->getParent()->createConstant(precision, {inDims[0]}, "dataScales");
2045   Constant *dataOffsets =
2046       F->getParent()->createConstant(precision, {inDims[0]}, "dataOffsets");
2047 
2048   // Note: Using floating point offset here as that is what RWQ-SLWS expects.
2049   switch (precision) {
2050   case ElemKind::FloatTy:
2051     quantization::tensorRowwiseQuantization<float, float, uint8_t>(
2052         data, rwqData->getPayloadMutable(), dataScales->getPayloadMutable(),
2053         dataOffsets->getPayloadMutable(), schema);
2054     break;
2055   case ElemKind::Float16Ty:
2056     quantization::tensorRowwiseQuantization<float16_t, float16_t, uint8_t>(
2057         data, rwqData->getPayloadMutable(), dataScales->getPayloadMutable(),
2058         dataOffsets->getPayloadMutable(), schema);
2059     break;
2060   default:
2061     LOG(FATAL) << "Unsupported precision for RWQ-SLWS.";
2062   }
2063   return F->createRowwiseQuantizedSparseLengthsWeightedSum(
2064       name, rwqData, dataScales, dataOffsets, weights, indices, lengths,
2065       precision, useFP16Accumulation, lengthsMode, avgLength);
2066 }
2067 
2068 RowwiseQuantizedSparseLengthsWeightedSumNode *
createRowwiseQuantizedSparseLengthsWeightedSum(llvm::StringRef name,Tensor & data,NodeValue weights,NodeValue indices,NodeValue lengths,quantization::Schema schema,ElemKind precision,bool useFP16Accumulation,LengthsMode lengthsMode,float avgLength)2069 Function::createRowwiseQuantizedSparseLengthsWeightedSum(
2070     llvm::StringRef name, Tensor &data, NodeValue weights, NodeValue indices,
2071     NodeValue lengths, quantization::Schema schema, ElemKind precision,
2072     bool useFP16Accumulation, LengthsMode lengthsMode, float avgLength) {
2073   return quantizeDataAndCreateRowwiseQuantizedSparseLengthsWeightedSum(
2074       this, name, data, weights, indices, lengths, schema, precision,
2075       useFP16Accumulation, lengthsMode, avgLength);
2076 }
2077 
2078 RowwiseQuantizedSparseLengthsWeightedSumNode *
createRowwiseQuantizedSparseLengthsSum(llvm::StringRef name,Tensor & data,NodeValue indices,NodeValue lengths,quantization::Schema schema,ElemKind precision,bool useFP16Accumulation,LengthsMode lengthsMode,float avgLength)2079 Function::createRowwiseQuantizedSparseLengthsSum(
2080     llvm::StringRef name, Tensor &data, NodeValue indices, NodeValue lengths,
2081     quantization::Schema schema, ElemKind precision, bool useFP16Accumulation,
2082     LengthsMode lengthsMode, float avgLength) {
2083   auto ty = getParent()->uniqueType(precision, {indices.dims()[0]});
2084   auto ones = createSplat(name.str() + ".ones", ty, 1.0);
2085   return quantizeDataAndCreateRowwiseQuantizedSparseLengthsWeightedSum(
2086       this, name, data, ones, indices, lengths, schema, precision,
2087       useFP16Accumulation, lengthsMode, avgLength);
2088 }
2089 
2090 /// Helper used to get specific output type required for
2091 /// createRowwiseQuantizedSparseLengthsSum,
2092 /// createRowwiseQuantizedSparseLengthsWeightedSum, and
2093 /// EmbeddingBagByteRowwiseOffsets. Function \p F is used to get the specific
2094 /// type, using inputs \p data and \p segmentsDim to compute output dimensions.
2095 static TypeRef
getOutputTypeOfFusedRowwiseQuantizedSLS(Function * F,NodeValue data,llvm::ArrayRef<dim_t> segmentsDim)2096 getOutputTypeOfFusedRowwiseQuantizedSLS(Function *F, NodeValue data,
2097                                         llvm::ArrayRef<dim_t> segmentsDim) {
2098   ShapeVector outDims(data.dims().begin(), data.dims().end());
2099   outDims[0] = segmentsDim[0];
2100   // The output column count is the same as the input column count, but
2101   // without the extra bytes for the fused scale/offset, as the output is not
2102   // fused.
2103   CHECK(isFusedQuantizedElemKind(data.getElementType()))
2104       << "Must use a fused ElemKind for data.";
2105   outDims[1] -= 2 * ((data.getElementType() == ElemKind::UInt8FusedQTy)
2106                          ? sizeof(float)
2107                          : sizeof(float16_t));
2108   // If using 4-bit quantization, then the input data has packed two 4-bit
2109   // elements into one byte, so we need to double the outDims.
2110   if (data.getElementType() == ElemKind::UInt4FusedFP16QTy) {
2111     outDims[1] *= 2;
2112   }
2113   const ElemKind outputK = (data.getElementType() == ElemKind::UInt8FusedQTy)
2114                                ? ElemKind::FloatTy
2115                                : ElemKind::Float16Ty;
2116   return F->getParent()->uniqueType(outputK, outDims);
2117 }
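// Worked example: for UInt8FusedQTy data of shape {100, 72} (64 payload bytes
// plus 8 bytes of fused float scale/offset) and lengths of shape {20}, the
// output type is FloatTy with shape {20, 64}. For UInt4FusedFP16QTy data of
// shape {100, 36} (32 packed bytes plus 4 bytes of fused float16 scale/offset)
// the output is Float16Ty with shape {20, 64}.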
2118 
2119 FusedRowwiseQuantizedSparseLengthsWeightedSumNode *
createFusedRowwiseQuantizedSparseLengthsWeightedSum(llvm::StringRef name,NodeValue data,NodeValue weights,NodeValue indices,NodeValue lengths,bool useFP16Accumulation,LengthsMode lengthsMode,float avgLength)2120 Function::createFusedRowwiseQuantizedSparseLengthsWeightedSum(
2121     llvm::StringRef name, NodeValue data, NodeValue weights, NodeValue indices,
2122     NodeValue lengths, bool useFP16Accumulation, LengthsMode lengthsMode,
2123     float avgLength) {
2124   auto outTy =
2125       getOutputTypeOfFusedRowwiseQuantizedSLS(this, data, lengths.dims());
2126   return addNode(new FusedRowwiseQuantizedSparseLengthsWeightedSumNode(
2127       name, outTy, data, weights, indices, lengths, useFP16Accumulation,
2128       lengthsMode, avgLength));
2129 }
2130 
2131 FusedRowwiseQuantizedSparseLengthsSumNode *
createFusedRowwiseQuantizedSparseLengthsSum(llvm::StringRef name,Storage * data,NodeValue indices,NodeValue lengths,bool useFP16Accumulation,LengthsMode lengthsMode,float avgLength)2132 Function::createFusedRowwiseQuantizedSparseLengthsSum(
2133     llvm::StringRef name, Storage *data, NodeValue indices, NodeValue lengths,
2134     bool useFP16Accumulation, LengthsMode lengthsMode, float avgLength) {
2135   auto outTy =
2136       getOutputTypeOfFusedRowwiseQuantizedSLS(this, data, lengths.dims());
2137   return addNode(new FusedRowwiseQuantizedSparseLengthsSumNode(
2138       name, outTy, data, indices, lengths, useFP16Accumulation, lengthsMode,
2139       avgLength));
2140 }
2141 
2142 /// Helper to get quantized data required for
2143 /// FusedRowwiseQuantizedSparseLengthsWeightedSumNode and
2144 /// FusedRowwiseQuantizedSparseLengthsSumNode. Function \p F uses float Tensor
2145 /// \p data to create a rowwise quantized Constant \p rwqData, which contains fused
2146 /// scales and offsets.
quantizeDataForFusedRowwiseQuantizedSparseLengthsWeightedSum(Function * F,Tensor & data,ElemKind precision)2147 static Constant *quantizeDataForFusedRowwiseQuantizedSparseLengthsWeightedSum(
2148     Function *F, Tensor &data, ElemKind precision) {
2149   // For fused rowwise quantization, we must have a two-dimensional input. If
2150   // passed a single-dimensional data Tensor, add an extra dimension.
2151   const auto fDims = flattenCdr(data.dims());
2152   Tensor fData = data.getUnowned({fDims.first, fDims.second});
2153 
2154   // Note: In rwqData, we are using a quantized type, however the scale/offset
2155   // are set to dummy values 0.0/0. This is because the actually used
2156   // scale/offset are fused inline with each row. Also, we expand the second
2157   // dimension to include space for the scale/offset, each 4 bytes
2158   // (float/int32_t).
2159   switch (precision) {
2160   case ElemKind::UInt8FusedQTy: {
2161     Constant *rwqData = F->getParent()->createConstant(
2162         precision, {fDims.first, fDims.second + 2 * (dim_t)sizeof(float)}, 0.0,
2163         0, "data");
2164     quantization::tensorFusedRowwiseQuantization<float>(
2165         fData, rwqData->getPayloadMutable());
2166     return rwqData;
2167   }
2168   case ElemKind::UInt8FusedFP16QTy: {
2169     Constant *rwqData = F->getParent()->createConstant(
2170         precision, {fDims.first, fDims.second + 2 * (dim_t)sizeof(float16_t)},
2171         0.0, 0, "data");
2172     quantization::tensorFusedRowwiseQuantization<float16_t>(
2173         fData, rwqData->getPayloadMutable());
2174     return rwqData;
2175   }
2176   case ElemKind::UInt4FusedFP16QTy: {
2177     // We pack 4-bit values into bytes, so given the input size in float we
2178     // divide by two and take the ceiling to make sure we have enough space for
2179     // all elements.
2180     const dim_t outerDim =
2181         std::ceil(((float)fDims.second) / 2) + 2 * sizeof(float16_t);
2182     Constant *rwqData = F->getParent()->createConstant(
2183         precision, {fDims.first, outerDim}, 0.0, 0, "data");
2184     quantization::tensorFusedRowwiseQuantization<float16_t>(
2185         fData, rwqData->getPayloadMutable());
2186     return rwqData;
2187   }
2188   default:
2189     llvm_unreachable("Invalid type for FusedRowwiswQuantization.");
2190   }
2191 }
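// Row-size example: for a float input row of 64 elements, the fused row is
// 64 + 8 bytes for UInt8FusedQTy (float scale/offset), 64 + 4 bytes for
// UInt8FusedFP16QTy (float16 scale/offset), and ceil(64 / 2) + 4 = 36 bytes
// for UInt4FusedFP16QTy, where two 4-bit values are packed per byte.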
2192 
2193 FusedRowwiseQuantizedSparseLengthsWeightedSumNode *
createFusedRowwiseQuantizedSparseLengthsWeightedSum(llvm::StringRef name,Tensor & data,NodeValue weights,NodeValue indices,NodeValue lengths,ElemKind fusedElemKind,bool useFP16Accumulation,LengthsMode lengthsMode,float avgLength)2194 Function::createFusedRowwiseQuantizedSparseLengthsWeightedSum(
2195     llvm::StringRef name, Tensor &data, NodeValue weights, NodeValue indices,
2196     NodeValue lengths, ElemKind fusedElemKind, bool useFP16Accumulation,
2197     LengthsMode lengthsMode, float avgLength) {
2198   Constant *rwqData =
2199       quantizeDataForFusedRowwiseQuantizedSparseLengthsWeightedSum(
2200           this, data, fusedElemKind);
2201   return createFusedRowwiseQuantizedSparseLengthsWeightedSum(
2202       name, rwqData, weights, indices, lengths, useFP16Accumulation,
2203       lengthsMode, avgLength);
2204 }
2205 
2206 FusedRowwiseQuantizedSparseLengthsSumNode *
createFusedRowwiseQuantizedSparseLengthsSum(llvm::StringRef name,Tensor & data,NodeValue indices,NodeValue lengths,ElemKind fusedElemKind,bool useFP16Accumulation,LengthsMode lengthsMode,float avgLength)2207 Function::createFusedRowwiseQuantizedSparseLengthsSum(
2208     llvm::StringRef name, Tensor &data, NodeValue indices, NodeValue lengths,
2209     ElemKind fusedElemKind, bool useFP16Accumulation, LengthsMode lengthsMode,
2210     float avgLength) {
2211   Constant *rwqData =
2212       quantizeDataForFusedRowwiseQuantizedSparseLengthsWeightedSum(
2213           this, data, fusedElemKind);
2214   return this->createFusedRowwiseQuantizedSparseLengthsSum(
2215       name, rwqData, indices, lengths, useFP16Accumulation, lengthsMode,
2216       avgLength);
2217 }
2218 
2219 EmbeddingBagNode *
createEmbeddingBag(llvm::StringRef name,NodeValue data,NodeValue weights,NodeValue indices,NodeValue offsets,bool hasEndOffset,LengthsMode lengthsMode,float avgLength)2220 Function::createEmbeddingBag(llvm::StringRef name, NodeValue data,
2221                              NodeValue weights, NodeValue indices,
2222                              NodeValue offsets, bool hasEndOffset,
2223                              LengthsMode lengthsMode, float avgLength) {
2224   auto inDims = data.dims();
2225   ShapeVector outDims(inDims.begin(), inDims.end());
2226   outDims[0] = hasEndOffset ? offsets.dims()[0] - 1 : offsets.dims()[0];
2227   auto outTy = getParent()->uniqueTypeWithNewShape(data.getType(), outDims);
2228   return addNode(new EmbeddingBagNode(name, outTy, data, weights, indices,
2229                                       offsets, hasEndOffset, lengthsMode,
2230                                       avgLength));
2231 }
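// Illustrative shape sketch for createEmbeddingBag (hypothetical sizes, not
// from the sources): with data of shape {vocabSize, 16} and offsets of shape
// {8}, the output has shape {8, 16} when hasEndOffset is false, or {7, 16}
// when the trailing offset only marks the end of the last segment.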
2232 
2233 EmbeddingBagByteRowwiseOffsetsNode *
2234 Function::createEmbeddingBagByteRowwiseOffsets(
2235     llvm::StringRef name, Tensor &data, NodeValue weights, NodeValue indices,
2236     NodeValue offsets, ElemKind fusedElemKind, bool useFP16Accumulation,
2237     bool hasEndOffset, LengthsMode lengthsMode, float avgLength) {
2238   Constant *rwqData =
2239       quantizeDataForFusedRowwiseQuantizedSparseLengthsWeightedSum(
2240           this, data, fusedElemKind);
2241   return createEmbeddingBagByteRowwiseOffsets(
2242       name, rwqData, weights, indices, offsets, useFP16Accumulation,
2243       hasEndOffset, lengthsMode, avgLength);
2244 }
2245 
2246 EmbeddingBagByteRowwiseOffsetsNode *
2247 Function::createEmbeddingBagByteRowwiseOffsets(
2248     llvm::StringRef name, NodeValue data, NodeValue weights, NodeValue indices,
2249     NodeValue offsets, bool useFP16Accumulation, bool hasEndOffset,
2250     LengthsMode lengthsMode, float avgLength) {
2251   std::vector<dim_t> segmentDims(offsets.dims().begin(), offsets.dims().end());
2252   // If hasEndOffset is true, the trailing offset only marks the end of the
2253   // last segment, so the number of output segments is offsets.dims()[0] - 1.
2254   if (hasEndOffset) {
2255     segmentDims[0] -= 1;
2256   }
2257   auto outTy = getOutputTypeOfFusedRowwiseQuantizedSLS(this, data, segmentDims);
2258   return addNode(new EmbeddingBagByteRowwiseOffsetsNode(
2259       name, outTy, data, weights, indices, offsets, useFP16Accumulation,
2260       hasEndOffset, lengthsMode, avgLength));
2261 }
2262 
2263 LengthsToRangesNode *Function::createLengthsToRanges(llvm::StringRef name,
2264                                                      NodeValue lengths) {
2265   ShapeVector outDims({lengths.dims()[0], 2});
2266   auto outTy = getParent()->uniqueTypeWithNewShape(lengths.getType(), outDims);
2267   return addNode(new LengthsToRangesNode(name, outTy, lengths));
2268 }
2269 
2270 LengthsRangeFillNode *
2271 Function::createLengthsRangeFill(llvm::StringRef name, NodeValue lengths,
2272                                  unsigned_t maxOutputSize) {
2273   auto outTy =
2274       getParent()->uniqueTypeWithNewShape(lengths.getType(), {maxOutputSize});
2275   return addNode(new LengthsRangeFillNode(name, outTy, lengths));
2276 }
2277 
2278 SparseToDenseNode *Function::createSparseToDense(llvm::StringRef name,
2279                                                  NodeValue indices,
2280                                                  NodeValue values,
2281                                                  NodeValue dataToInferDim) {
2282   // The dimensions of the output are the same as the values tensor except for
2283   // the first dimension, which should match that of dataToInferDim.
2284   ShapeVector outDims(values.dims().begin(), values.dims().end());
2285   outDims[0] = dataToInferDim.dims()[0];
2286   auto outTy = getParent()->uniqueTypeWithNewShape(values.getType(), outDims);
2287   return addNode(new SparseToDenseNode(name, outTy, indices, values));
2288 }
2289 
2290 SparseToDenseMaskNode *Function::createSparseToDenseMask(
2291     llvm::StringRef name, NodeValue indices, NodeValue values,
2292     NodeValue defaultValue, NodeValue lengths, llvm::ArrayRef<dim_t> mask) {
2293   auto lengthsDims = lengths.dims();
2294   auto valueDims = defaultValue.dims();
2295   ShapeVector outDims = {(dim_t)mask.size()};
2296   // If lengths is a 0-dimensional tensor, then there is no batch dimension.
2297   if (lengthsDims.size() > 0) {
2298     outDims.insert(outDims.begin(), lengthsDims[0]);
2299   }
2300   outDims.insert(outDims.end(), valueDims.begin(), valueDims.end());
2301   auto outTy = getParent()->uniqueTypeWithNewShape(values.getType(), outDims);
2302   return addNode(new SparseToDenseMaskNode(name, outTy, indices, values,
2303                                            defaultValue, lengths, mask));
2304 }
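// Illustrative shape sketch for createSparseToDenseMask (hypothetical sizes):
// with a mask of size 5, lengths of shape {4} and defaultValue of shape {2},
// the output shape is {4, 5, 2}; with a 0-dimensional lengths tensor it would
// be {5, 2}.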
2305 
2306 SaveNode *Function::createSave(llvm::StringRef name, NodeValue input) {
2307   auto *dest = getParent()->createPlaceholder(input.getType(), name, false);
2308   return createSave(name, input, dest);
2309 }
2310 
2311 SaveNode *Function::createSave(llvm::StringRef name, NodeValue input,
2312                                Placeholder *output, bool skipSuffix) {
2313   return addNode(new SaveNode(skipSuffix ? name.str() : (name + "_save").str(),
2314                               input, output));
2315 }
2316 
2317 QuantizationProfileNode *
2318 Function::createQuantizationProfile(PlaceholderBindings &bindings,
2319                                     llvm::StringRef name, NodeValue input,
2320                                     dim_t numHistogramBins) {
2321   auto *histogram = getParent()->createPlaceholder(
2322       ElemKind::FloatTy, {numHistogramBins}, "histogram_" + name.str(), false);
2323   bindings.allocate(histogram)->zero();
2324   // Intermediate data used for histogram calculations.
2325   // Min tensor value seen so far is kept on the first position.
2326   // Max tensor value seen so far is kept on the second position.
2327   auto *computationInfoPH = getParent()->createPlaceholder(
2328       ElemKind::FloatTy, {2}, "CI_" + name.str(), false);
2329   bindings.allocate(computationInfoPH);
2330   auto *computationInfoTensor = bindings.get(computationInfoPH);
2331   auto handle = computationInfoTensor->getHandle<float>();
2332   handle.raw(0) = std::numeric_limits<float>::max();
2333   handle.raw(1) = std::numeric_limits<float>::lowest();
2334 
2335   return addNode(new QuantizationProfileNode(
2336       "QI_" + name.str(), input, histogram, computationInfoPH,
2337       input.getNode()->getName().str(), input.getResNo()));
2338 }
2339 
2340 IntLookupTableNode *
2341 Function::createIntLookupTable(llvm::StringRef name, NodeValue input,
2342                                llvm::ArrayRef<int8_t> initValues,
2343                                TypeRef outTy) {
2344   auto *mapping = getParent()->createConstant(
2345       ElemKind::Int8QTy, {(dim_t)initValues.size()}, outTy->getScale(),
2346       outTy->getOffset(), "mapping");
2347   mapping->getHandle<int8_t>() = initValues;
2348 
2349   return addNode(new IntLookupTableNode(name, outTy, input, mapping));
2350 }
2351 
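// The static tables below map every possible quantized int8 input value to a
// precomputed quantized output value; entry i is intended to correspond to
// the input value (i - 128). The concrete values assume the fixed input/output
// quantization ranges these lookup-table activations were generated for, and
// are consumed by the IntLookupTableNode created in createIntLookupTable().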
2352 IntLookupTableNode *Function::createIntTanh(llvm::StringRef name,
2353                                             NodeValue input, TypeRef outTy) {
2354   static int8_t mapping[] = {
2355       -128, -127, -126, -126, -126, -126, -126, -126, -126, -126, -126, -126,
2356       -126, -126, -126, -126, -126, -126, -126, -126, -125, -125, -125, -125,
2357       -125, -125, -125, -125, -125, -125, -125, -124, -124, -124, -124, -124,
2358       -124, -124, -123, -123, -123, -123, -123, -123, -122, -122, -122, -122,
2359       -121, -121, -121, -120, -120, -120, -120, -119, -119, -118, -118, -118,
2360       -117, -117, -116, -116, -115, -115, -114, -114, -113, -112, -112, -111,
2361       -110, -109, -109, -108, -107, -106, -105, -104, -103, -102, -101, -100,
2362       -99,  -98,  -96,  -95,  -94,  -92,  -91,  -89,  -88,  -86,  -85,  -83,
2363       -81,  -79,  -77,  -76,  -74,  -72,  -69,  -67,  -65,  -63,  -61,  -58,
2364       -56,  -53,  -51,  -48,  -46,  -43,  -41,  -38,  -35,  -32,  -29,  -27,
2365       -24,  -21,  -18,  -15,  -12,  -9,   -6,   -3,   0,    3,    6,    9,
2366       12,   15,   18,   21,   24,   27,   29,   32,   35,   38,   41,   43,
2367       46,   48,   51,   53,   56,   58,   61,   63,   65,   67,   69,   72,
2368       74,   76,   77,   79,   81,   83,   85,   86,   88,   89,   91,   92,
2369       94,   95,   96,   98,   99,   100,  101,  102,  103,  104,  105,  106,
2370       107,  108,  109,  109,  110,  111,  112,  112,  113,  114,  114,  115,
2371       115,  116,  116,  117,  117,  118,  118,  118,  119,  119,  120,  120,
2372       120,  120,  121,  121,  121,  122,  122,  122,  122,  123,  123,  123,
2373       123,  123,  123,  124,  124,  124,  124,  124,  124,  124,  125,  125,
2374       125,  125,  125,  125,  125,  125,  125,  125,  125,  126,  126,  126,
2375       126,  126,  126,  126,  126,  126,  126,  126,  126,  126,  126,  126,
2376       126,  126,  126,  127};
2377 
2378   return createIntLookupTable(name, input, mapping, outTy);
2379 }
2380 
2381 IntLookupTableNode *Function::createIntSigmoid(llvm::StringRef name,
2382                                                NodeValue input, TypeRef outTy) {
2383   static int8_t mapping[] = {
2384       -128, -127, -127, -127, -127, -127, -127, -127, -127, -127, -127, -127,
2385       -127, -127, -127, -127, -127, -127, -127, -126, -126, -126, -126, -126,
2386       -126, -126, -126, -126, -126, -126, -125, -125, -125, -125, -125, -125,
2387       -125, -125, -124, -124, -124, -124, -124, -123, -123, -123, -123, -122,
2388       -122, -122, -122, -121, -121, -121, -120, -120, -120, -119, -119, -118,
2389       -118, -118, -117, -117, -116, -115, -115, -114, -114, -113, -112, -112,
2390       -111, -110, -109, -109, -108, -107, -106, -105, -104, -103, -102, -101,
2391       -99,  -98,  -97,  -96,  -94,  -93,  -91,  -90,  -88,  -87,  -85,  -83,
2392       -82,  -80,  -78,  -76,  -74,  -72,  -70,  -68,  -66,  -63,  -61,  -59,
2393       -56,  -54,  -51,  -49,  -46,  -44,  -41,  -38,  -36,  -33,  -30,  -27,
2394       -24,  -21,  -18,  -15,  -12,  -9,   -6,   -3,   -1,   2,    5,    8,
2395       11,   14,   17,   20,   23,   26,   29,   32,   35,   37,   40,   43,
2396       45,   48,   50,   53,   55,   58,   60,   62,   65,   67,   69,   71,
2397       73,   75,   77,   79,   81,   82,   84,   86,   87,   89,   90,   92,
2398       93,   95,   96,   97,   98,   100,  101,  102,  103,  104,  105,  106,
2399       107,  108,  108,  109,  110,  111,  111,  112,  113,  113,  114,  114,
2400       115,  116,  116,  117,  117,  117,  118,  118,  119,  119,  119,  120,
2401       120,  120,  121,  121,  121,  121,  122,  122,  122,  122,  123,  123,
2402       123,  123,  123,  124,  124,  124,  124,  124,  124,  124,  124,  125,
2403       125,  125,  125,  125,  125,  125,  125,  125,  125,  125,  126,  126,
2404       126,  126,  126,  126,  126,  126,  126,  126,  126,  126,  126,  126,
2405       126,  126,  126,  127};
2406 
2407   return createIntLookupTable(name, input, mapping, outTy);
2408 }
2409 
2410 TopKNode *Function::createTopK(llvm::StringRef name, NodeValue input,
2411                                unsigned_t k, ElemKind outIndicesTyKind) {
2412   auto inDims = input.dims();
2413   assert(inDims.size() > 0);
2414   assert(k <= inDims.back());
2415   ShapeVector outDims(inDims.begin(), inDims.end());
2416   outDims.back() = k;
2417   auto OT = getParent()->uniqueTypeWithNewShape(input.getType(), outDims);
2418   return addNode(new TopKNode(
2419       name, OT, getParent()->uniqueType(outIndicesTyKind, outDims), input, k));
2420 }
2421 
2422 TopKNode *Function::createTopK(llvm::StringRef name, NodeValue input,
2423                                unsigned_t k) {
2424   return createTopK(name, input, k, ElemKind::Int64ITy);
2425 }
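// Minimal usage sketch (hypothetical Function *F and input, not from the
// sources): for an input of shape {8, 100},
//   TopKNode *TK = F->createTopK("topk", in, /* k */ 5);
// produces a Values result of shape {8, 5} and an Indices result of the same
// shape, with element kind Int64ITy in this two-argument overload.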
2426 
2427 ArgMaxNode *Function::createArgMax(llvm::StringRef name, NodeValue input,
2428                                    unsigned_t axis, bool keepDims,
2429                                    ElemKind elemTy) {
2430   ShapeVector outDims = reduceDims(input.dims(), {axis}, keepDims);
2431   auto OT = getParent()->uniqueType(elemTy, outDims);
2432   return addNode(new ArgMaxNode(name, OT, input, axis, keepDims));
2433 }
2434 
2435 ArgMinNode *Function::createArgMin(llvm::StringRef name, NodeValue input,
2436                                    unsigned_t axis, bool keepDims,
2437                                    ElemKind elemTy) {
2438   ShapeVector outDims = reduceDims(input.dims(), {axis}, keepDims);
2439   auto OT = getParent()->uniqueType(elemTy, outDims);
2440   return addNode(new ArgMinNode(name, OT, input, axis, keepDims));
2441 }
2442 
2443 GatherNode *Function::createGather(llvm::StringRef name, NodeValue data,
2444                                    NodeValue indices, unsigned_t batchDims) {
2445   auto dDims = data.dims();
2446   auto iDims = indices.dims();
2447   assert(dDims.size() > batchDims);
2448   ShapeVector outDims;
2449   outDims.insert(outDims.end(), dDims.begin(), dDims.begin() + batchDims);
2450   outDims.insert(outDims.end(), iDims.begin(), iDims.end());
2451   outDims.insert(outDims.end(), dDims.begin() + batchDims + 1, dDims.end());
2452   return addNode(new GatherNode(
2453       name, getParent()->uniqueTypeWithNewShape(data.getType(), outDims), data,
2454       indices, batchDims));
2455 }
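// Illustrative shape sketch for createGather (hypothetical sizes): with data
// of shape {5, 6}, indices of shape {2, 3} and batchDims == 0, the output has
// shape {2, 3, 6}: the indexed dimension of data is replaced by the dims of
// indices, while the leading batchDims dimensions are kept as-is.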
2456 
2457 GatherRangesNode *Function::createGatherRanges(llvm::StringRef name,
2458                                                NodeValue data, NodeValue ranges,
2459                                                unsigned_t maxOutputSize) {
2460   auto numRanges = ranges.dims()[0];
2461   return addNode(new GatherRangesNode(
2462       name,
2463       /*OutputTy=*/
2464       getParent()->uniqueTypeWithNewShape(data.getType(), {maxOutputSize}),
2465       /*LengthsTy=*/
2466       getParent()->uniqueTypeWithNewShape(ranges.getType(), numRanges), data,
2467       ranges));
2468 }
2469 
2470 ScatterDataNode *Function::createScatterData(llvm::StringRef name,
2471                                              NodeValue data, NodeValue indices,
2472                                              NodeValue slices,
2473                                              bool cumulative) {
2474   return addNode(new ScatterDataNode(name, data, indices, slices, cumulative));
2475 }
2476 
2477 BatchOneHotNode *Function::createBatchOneHot(llvm::StringRef name,
2478                                              NodeValue data, NodeValue lengths,
2479                                              NodeValue values) {
2480   auto outTy = getParent()->uniqueTypeWithNewShape(
2481       data.getType(), {data.dims()[0], values.dims()[0]});
2482   return addNode(new BatchOneHotNode(name, outTy, data, lengths, values));
2483 }
2484 
2485 SpaceToDepthNode *Function::createSpaceToDepth(llvm::StringRef name,
2486                                                NodeValue input,
2487                                                unsigned blockSize) {
2488   assert(blockSize > 0 && "BlockSize must be >= 1.");
2489 
2490   auto inputDim = input.dims();
2491   assert(inputDim.size() == 4 && "Dimension size of 4 is expected.");
2492   assert((inputDim[1] % blockSize == 0 && inputDim[2] % blockSize == 0) &&
2493          "Height and Width need to be multiples of blockSize.");
2494   std::vector<dim_t> newDim = {inputDim[0], inputDim[1] / blockSize,
2495                                inputDim[2] / blockSize,
2496                                inputDim[3] * blockSize * blockSize};
2497   auto outTy = getParent()->uniqueTypeWithNewShape(input.getType(), newDim);
2498   return addNode(new SpaceToDepthNode(name, outTy, input, blockSize));
2499 }
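// Illustrative shape sketch for createSpaceToDepth (hypothetical sizes): an
// NHWC input of shape {1, 8, 8, 3} with blockSize == 2 yields an output of
// shape {1, 4, 4, 12}, moving each 2x2 spatial block into the channel
// dimension.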
2500 
2501 ReshapeNode *Function::createUpsample(llvm::StringRef name, NodeValue input,
2502                                       dim_t numLeadingDims) {
2503   auto dims = input.dims();
2504   DCHECK_LE(numLeadingDims, dims.size())
2505       << "numLeadingDims " << numLeadingDims
2506       << " must be less than or equal to the total num dims " << dims.size();
2507   dim_t dim0Dims = 1;
2508   dim_t dim1Dims = 1;
2509   for (dim_t d = 0; d < dims.size(); d++) {
2510     dim0Dims *= dims[d];
2511   }
2512 
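  // Starting from the innermost dimension, repeatedly flatten the tensor to
  // {outer, inner}, tile it by 2 along axis 1, and fold the doubled size back
  // into the shape. Each pass duplicates every element along one of the last
  // numLeadingDims dimensions, so the final reshape yields a 2x
  // nearest-neighbor upsample of those dimensions.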
2513   Node *cur = input;
2514   for (dim_t d = 0; d < numLeadingDims; d++) {
2515     auto *reshaped =
2516         createReshape(name.str() + "_dim_reshape", cur, {dim0Dims, dim1Dims});
2517     cur =
2518         createTile(name.str() + "_tile", reshaped, /* tiles */ 2, /* axis */ 1);
2519     dim_t sz = dims[dims.size() - d - 1];
2520     dim0Dims /= sz;
2521     dim1Dims *= 2 * sz;
2522   }
2523   std::vector<dim_t> outDims(dims.begin(), dims.end());
2524   for (dim_t d = dims.size() - numLeadingDims; d < dims.size(); d++) {
2525     outDims[d] *= 2;
2526   }
2527   return createReshape(name.str() + "_last_reshape", cur, outDims);
2528 }
2529 
2530 ResizeNearestNode *Function::createResizeNearest(llvm::StringRef name,
2531                                                  NodeValue input,
2532                                                  llvm::ArrayRef<float> scale) {
2533   auto inputDim = input.dims();
2534   DCHECK_EQ(inputDim.size(), scale.size())
2535       << "Input Dimension size: " << inputDim.size()
2536       << " Scale size: " << scale.size() << " should be the same.";
2537 
2538   std::vector<dim_t> newDim;
2539 
2540   for (size_t i = 0; i < scale.size(); i++) {
2541     auto newD = dim_t(std::floor(inputDim[i] * scale[i]));
2542     DCHECK_GT(newD, 0) << "Scaled dim is " << newD
2543                        << ", Scaled value needs to be larger than 0.";
2544     newDim.push_back(newD);
2545   }
2546 
2547   auto outTy = getParent()->uniqueTypeWithNewShape(input.getType(), newDim);
2548   return addNode(new ResizeNearestNode(name, outTy, input, scale));
2549 }
2550 
2551 ResizeNearestNode *Function::createResizeNearest(llvm::StringRef name,
2552                                                  NodeValue input,
2553                                                  TypeRef outTy) {
2554   auto inputDim = input.dims();
2555   auto outputDim = outTy->dims();
2556   DCHECK_EQ(inputDim.size(), outputDim.size())
2557       << "Input dimension size: " << inputDim.size()
2558       << " output dimension size: " << outputDim.size() << " should be the same.";
2559 
2560   std::vector<float> scales;
2561   for (size_t i = 0; i < inputDim.size(); i++) {
2562     float scale = (outputDim[i] / (float)inputDim[i]);
2563     DCHECK_GT(scale, 0.0) << "Scale: " << scale
2564                           << ", Scale larger than 0 is expected.";
2565     scales.push_back(scale);
2566   }
2567 
2568   return addNode(new ResizeNearestNode(name, outTy, input, scales));
2569 }
2570 
2571 ResizeBilinearNode *
2572 Function::createResizeBilinear(llvm::StringRef name, NodeValue input,
2573                                llvm::ArrayRef<float> scale) {
2574   auto inputDim = input.dims();
2575   DCHECK_EQ(inputDim.size(), scale.size())
2576       << "Input Dimension size: " << inputDim.size()
2577       << " Scale size: " << scale.size() << " should be the same.";
2578 
2579   std::vector<dim_t> newDim;
2580 
2581   for (size_t i = 0; i < scale.size(); i++) {
2582     auto newD = dim_t(std::floor(inputDim[i] * scale[i]));
2583     DCHECK_GT(newD, 0) << "Scaled dim is " << newD
2584                        << ", Scaled value needs to be larger than 0.";
2585     newDim.push_back(newD);
2586   }
2587 
2588   auto outTy = getParent()->uniqueTypeWithNewShape(input.getType(), newDim);
2589   return addNode(new ResizeBilinearNode(name, outTy, input, scale));
2590 }
2591 
2592 ResizeBilinearNode *Function::createResizeBilinear(llvm::StringRef name,
2593                                                    NodeValue input,
2594                                                    TypeRef outTy) {
2595   auto inputDim = input.dims();
2596   auto outputDim = outTy->dims();
2597   DCHECK_EQ(inputDim.size(), outputDim.size())
2598       << "Input dimension size: " << inputDim.size()
2599       << " output dimension size: " << outputDim.size() << " should be the same.";
2600 
2601   std::vector<float> scales;
2602   for (size_t i = 0; i < inputDim.size(); i++) {
2603     float scale = (outputDim[i] / (float)inputDim[i]);
2604     DCHECK_GT(scale, 0.0) << "Scale: " << scale
2605                           << ", Scale larger than 0 is expected.";
2606     scales.push_back(scale);
2607   }
2608 
2609   return addNode(new ResizeBilinearNode(name, outTy, input, scales));
2610 }
2611 
2612 QuantizeNode *Function::createQuantize(llvm::StringRef name, NodeValue input,
2613                                        TypeRef outTy) {
2614   assert(input.getType()->isFPType() && "Input must be a floating type");
2615   assert(outTy->isQuantizedType() && "Output must be a quantized type");
2616   assert(input.dims().equals(outTy->dims()) &&
2617          "Different dimensions for input and output");
2618 
2619   return addNode(
2620       new QuantizeNode(name, getParent()->uniqueType(*outTy), input));
2621 }
2622 
2623 DequantizeNode *Function::createDequantize(llvm::StringRef name,
2624                                            NodeValue input, ElemKind k) {
2625   assert(input.getType()->isQuantizedType() &&
2626          "Input must be a quantized type");
2627   assert(isFloatElemKind(k) && "Result must be float type.");
2628   ShapeVector outShape(input.dims().begin(), input.dims().end());
2629   if (input.getElementType() == ElemKind::UInt8FusedQTy) {
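    // In the fused format each row stores its quantized data followed by a
    // float scale and a float offset (2 * sizeof(float) bytes), so the
    // dequantized output drops those trailing columns; e.g. a {rows, 108}
    // UInt8FusedQTy input dequantizes to a {rows, 100} float tensor
    // (hypothetical sizes).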
2630     assert(outShape.size() == 2 && "Fused tensors should be 2D");
2631     assert(outShape[1] > 2 * sizeof(float) &&
2632            "Expected space for per-row scale/offset");
2633     outShape[1] -= 2 * sizeof(float);
2634   }
2635   TypeRef outTy = getParent()->uniqueType(Type(k, outShape));
2636   return createDequantize(name, input, outTy);
2637 }
2638 
2639 DequantizeNode *Function::createDequantize(llvm::StringRef name,
2640                                            NodeValue input, TypeRef outTy) {
2641   assert(input.getType()->isQuantizedType() &&
2642          "Input must be a quantized type");
2643   assert(outTy->isFPType() && "Output should be an FP type");
2644   return addNode(new DequantizeNode(name, outTy, input));
2645 }
2646 
2647 RescaleQuantizedNode *Function::createRescaleQuantized(llvm::StringRef name,
2648                                                        NodeValue input,
2649                                                        TypeRef outTy) {
2650   assert(input.getType()->isQuantizedType() &&
2651          "Input must be a quantized type");
2652   assert(outTy->isQuantizedType() && "Output must be a quantized type");
2653   assert(input.dims().equals(outTy->dims()) &&
2654          "Different dimensions for input and output");
2655 
2656   return addNode(
2657       new RescaleQuantizedNode(name, getParent()->uniqueType(*outTy), input));
2658 }
2659 
2660 Node *Function::createWeightedSum(llvm::StringRef name,
2661                                   llvm::ArrayRef<NodeValue> data,
2662                                   llvm::ArrayRef<NodeValue> weights) {
2663   assert(data.size() == weights.size() &&
2664          "Must have same number of data and weights.");
2665   assert(data.size() > 0 && "No inputs provided.");
2666 
2667   const auto *outTy = data[0].getType();
2668 
2669   // Create a zero splat to bootstrap the adding chain.
2670   Node *currAdd = createSplat(name.str() + ".splat", outTy, 0.);
2671 
2672   for (size_t i = 0, e = data.size(); i < e; i++) {
2673     assert(weights[i].getType()->size() == 1 &&
2674            "Each provided weight node must be size 1.");
2675     assert(outTy == data[i].getType() &&
2676            "All data nodes must have the same type.");
2677 
2678     // Broadcast the current weight to same shape as the data.
2679     auto *bcastW =
2680         createBroadcast(name.str() + ".bcastWeight" + std::to_string(i),
2681                         weights[i], outTy->dims(), /* axis */ 0);
2682 
2683     // Element-wise multiply the broadcasted weight by the data.
2684     auto *scaledD =
2685         createMul(name.str() + ".mul" + std::to_string(i), bcastW, data[i]);
2686 
2687     // Element-wise add the scaled data to the running total.
2688     currAdd =
2689         createAdd(name.str() + ".add" + std::to_string(i), scaledD, currAdd);
2690   }
2691 
2692   // Return the final weighted sum via the last add in the chain.
2693   return currAdd;
2694 }
2695 
2696 Node *Function::createBatchBoxCox(llvm::StringRef name, NodeValue data,
2697                                   NodeValue lambda1, NodeValue lambda2,
2698                                   float epsilon) {
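  // Rough sketch of the transform computed by BatchBoxCoxNode (following the
  // usual Box-Cox / Caffe2-style semantics, as understood here): per column j,
  //   y = (pow(x + lambda2[j], lambda1[j]) - 1) / lambda1[j]   if lambda1[j] != 0
  //   y = log(x + lambda2[j])                                  otherwise,
  // with epsilon guarding the numerically degenerate cases (e.g. lambda1 near
  // zero).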
2699   assert((lambda1.dims() == lambda2.dims()) &&
2700          "lambda1 and lambda2 must have the same shape");
2701   assert((lambda1.getType()->getElementType() == lambda2.getElementType()) &&
2702          "lambda1 and lambda2 must have the same element type");
2703   assert((lambda1.getType()->getElementType() == data.getElementType()) &&
2704          "data and lambdas must have the same element type");
2705   assert((lambda1.dims().size() == 1) && "lambda1 and lambda2 must be vectors");
2706   assert((data.dims().size() == 2) && "data must be a matrix");
2707   assert((data.dims()[1] == lambda1.dims()[0]) &&
2708          "data's number of columns must match the size of lambda1 and lambda2");
2709 
2710   return addNode(new BatchBoxCoxNode(name, data, lambda1, lambda2, epsilon));
2711 }
2712 
2713 ClipNode *Function::createClip(llvm::StringRef name, NodeValue input,
2714                                TypeRef outTy, float min, float max) {
2715   return addNode(new ClipNode(name, outTy, input, min, max));
2716 }
2717 
2718 ClipNode *Function::createClip(llvm::StringRef name, NodeValue input, float min,
2719                                float max) {
2720   return addNode(new ClipNode(name, input.getType(), input, min, max));
2721 }
2722 
2723 ClipNode *Function::createClipMinMaxFP16(llvm::StringRef name,
2724                                          NodeValue input) {
2725   constexpr float float16Min = -65504.0f;
2726   constexpr float float16Max = 65504.0f;
2727   return createClip(name, input, float16Min, float16Max);
2728 }
2729 
2730 ClipNode *Function::createClipMinMaxBFloat16(llvm::StringRef name,
2731                                              NodeValue input) {
2732   constexpr float bfloat16Min = FLT_MIN;
2733   constexpr float bfloat16Max = FLT_MAX;
2734   return createClip(name, input, bfloat16Min, bfloat16Max);
2735 }
2736 
2737 //===----------------------------------------------------------------------===//
2738 //                   Placeholder-builder methods.
2739 //===----------------------------------------------------------------------===//
2740 
2741 BatchNormalizationNode *Function::createBatchNormalization(
2742     PlaceholderBindings &bindings, llvm::StringRef name, NodeValue input,
2743     unsigned_t channelIdx, float epsilon, float momentum) {
2744   // Figure out how many channels are in the tensor.
2745   dim_t channels = input.dims()[channelIdx];
2746 
2747   ElemKind inputTy = input.getType()->getElementType();
2748 
2749   // Allocate the learnable parameters beta and gamma.
2750   auto *beta =
2751       getParent()->createPlaceholder(inputTy, {channels}, "beta", true);
2752   bindings.allocate(beta)->init(Tensor::InitKind::Broadcast, 0.1, getPRNG());
2753 
2754   auto *scale =
2755       getParent()->createPlaceholder(inputTy, {channels}, "scale", true);
2756   bindings.allocate(scale)->init(Tensor::InitKind::Broadcast, 0.001, getPRNG());
2757 
2758   auto *mean =
2759       getParent()->createPlaceholder(inputTy, {channels}, "mean", false);
2760   bindings.allocate(mean)->zero();
2761 
2762   auto *variance =
2763       getParent()->createPlaceholder(inputTy, {channels}, "variance", false);
2764   bindings.allocate(variance)->init(Tensor::InitKind::Broadcast, 1.0,
2765                                     getPRNG());
2766 
2767   return createBatchNormalization(name, input, beta, scale, mean, variance,
2768                                   channelIdx, epsilon, momentum);
2769 }
2770 
2771 ConvolutionNode *Function::createConv(
2772     PlaceholderBindings &bindings, llvm::StringRef name, NodeValue input,
2773     dim_t outChannels, llvm::ArrayRef<unsigned_t> kernels,
2774     llvm::ArrayRef<unsigned_t> strides, llvm::ArrayRef<unsigned_t> pads,
2775     unsigned_t group, unsigned_t dilation, ConvolutionLayout layout) {
2776   ShapeNHWC idim = ShapeNHWC(input.dims());
2777   ShapeHW kdim(kernels);
2778   PaddingTLBR pdim(pads);
2779   (void)pdim;
2780   assert((idim.w + pdim.left + pdim.right) >= kdim.width &&
2781          (idim.h + pdim.top + pdim.bottom) >= kdim.height &&
2782          "buffer too small for selected stride");
2783 
2784   assert(group > 0 && "group should be larger than 0");
2785   assert(idim.c % group == 0 && "channels number must be divisible by groups");
2786   assert(outChannels % group == 0 && "outChannels must be divisible by groups");
2787 
2788   // Calculate the size and allocate the output buffer.
2789   auto outSz = calculateConvPoolOutputDims(idim.h, idim.w, kernels, strides,
2790                                            pads, dilation);
2791 
2792   std::array<dim_t, 4> outDims = {
2793       {idim.n, outSz.first, outSz.second, outChannels}};
2794 
2795   // Allocate the Filter and Bias tensors.
2796   std::array<dim_t, 4> filterDim = {
2797       {outChannels, kdim.height, kdim.width, idim.c / group}};
2798   size_t fanIn = kdim.height * kdim.width * idim.c;
2799   ElemKind inputTy = input.getType()->getElementType();
2800   assert(isFloatElemKind(inputTy) && "Convolution on non-floating point type?");
2801   auto *filter =
2802       getParent()->createPlaceholder(inputTy, filterDim, "filter", true);
2803   bindings.allocate(filter)->init(glow::Tensor::InitKind::Xavier, fanIn,
2804                                   getPRNG());
2805 
2806   auto *bias =
2807       getParent()->createPlaceholder(inputTy, {outChannels}, "bias", true);
2808   bindings.allocate(bias)->init(glow::Tensor::InitKind::Broadcast, 0.1,
2809                                 getPRNG());
2810 
2811   auto OT = getParent()->uniqueType(inputTy, outDims);
2812 
2813   return addNode(new ConvolutionNode(name, OT, input, filter, bias, kernels,
2814                                      strides, pads, group, dilation, layout,
2815                                      FusedActivation::NONE));
2816 }
2817 
2818 ConvolutionNode *Function::createConv(PlaceholderBindings &bindings,
2819                                       llvm::StringRef name, NodeValue input,
2820                                       dim_t outChannels, unsigned_t kernel,
2821                                       unsigned_t stride, unsigned_t pad,
2822                                       unsigned_t group, unsigned_t dilation,
2823                                       ConvolutionLayout layout) {
2824   llvm::SmallVector<unsigned_t, 4> pads = {pad, pad, pad, pad};
2825   llvm::SmallVector<unsigned_t, 2> strides = {stride, stride};
2826   llvm::SmallVector<unsigned_t, 2> kernels = {kernel, kernel};
2827   return createConv(bindings, name, input, outChannels, kernels, strides, pads,
2828                     group, dilation, layout);
2829 }
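// Illustrative shape sketch for createConv (hypothetical sizes): an NHWC
// input of shape {1, 28, 28, 3} with outChannels == 16, kernel == 5,
// stride == 1 and pad == 2 gives an output of shape {1, 28, 28, 16}, since
// (28 + 2 * 2 - 5) / 1 + 1 == 28 in each spatial dimension (dilation == 1).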
2830 
2831 Convolution3DNode *Function::createConv3D(PlaceholderBindings &bindings,
2832                                           llvm::StringRef name, NodeValue input,
2833                                           dim_t outChannels,
2834                                           llvm::ArrayRef<unsigned_t> kernels,
2835                                           llvm::ArrayRef<unsigned_t> strides,
2836                                           llvm::ArrayRef<unsigned_t> pads,
2837                                           unsigned_t group) {
2838   ShapeNTHWC idim(input.dims());
2839   ShapeTHW kdim(kernels);
2840 
2841   assert(group > 0 && "group should be larger than 0");
2842   assert(idim.c % group == 0 && "channels number must be divisible by groups");
2843   assert(outChannels % group == 0 && "outChannels must be divisible by groups");
2844 
2845   // Calculate the size and allocate the output buffer.
2846   auto outSz = calculate3DConvPoolOutputDims(idim.t, idim.h, idim.w, kernels,
2847                                              strides, pads);
2848 
2849   std::array<dim_t, 5> outDims = {
2850       {idim.n, outSz.temporal_frames, outSz.height, outSz.width, outChannels}};
2851 
2852   // Allocate the Filter and Bias tensors.
2853   std::array<dim_t, 5> filterDim = {{outChannels, kdim.temporal_frames,
2854                                      kdim.height, kdim.width, idim.c / group}};
2855 
2856   dim_t fanIn = kdim.temporal_frames * kdim.height * kdim.width * idim.c;
2857   ElemKind inputTy = input.getType()->getElementType();
2858   assert(isFloatElemKind(inputTy) &&
2859          "Convolution3D on non-floating point type?");
2860   auto *filter =
2861       getParent()->createPlaceholder(inputTy, filterDim, "filter", true);
2862   bindings.allocate(filter)->init(glow::Tensor::InitKind::Xavier, fanIn,
2863                                   getPRNG());
2864 
2865   auto *bias =
2866       getParent()->createPlaceholder(inputTy, {outChannels}, "bias", true);
2867   bindings.allocate(bias)->init(glow::Tensor::InitKind::Broadcast, 0.1,
2868                                 getPRNG());
2869 
2870   auto OT = getParent()->uniqueType(inputTy, outDims);
2871 
2872   assertConv3DDims(input, filter, bias, kernels, strides, pads, group);
2873 
2874   return addNode(new Convolution3DNode(name, OT, input, filter, bias, kernels,
2875                                        strides, pads, group));
2876 }
2877 
2878 Convolution3DNode *Function::createConv3D(PlaceholderBindings &bindings,
2879                                           llvm::StringRef name, NodeValue input,
2880                                           size_t outChannels, unsigned_t kernel,
2881                                           unsigned_t stride, unsigned_t pad,
2882                                           unsigned_t group) {
2883   llvm::SmallVector<unsigned_t, 6> pads = {pad, pad, pad, pad, pad, pad};
2884   llvm::SmallVector<unsigned_t, 3> strides = {stride, stride, stride};
2885   llvm::SmallVector<unsigned_t, 3> kernels = {kernel, kernel, kernel};
2886   return createConv3D(bindings, name, input, outChannels, kernels, strides,
2887                       pads, group);
2888 }
2889 
2890 ChannelwiseQuantizedConvolutionNode *Function::createChannelwiseQuantizedConv(
2891     llvm::StringRef name, NodeValue input, NodeValue filter, NodeValue bias,
2892     NodeValue filterScales, NodeValue filterOffsets, NodeValue biasScales,
2893     NodeValue biasOffsets, TypeRef outTy, llvm::ArrayRef<unsigned_t> kernels,
2894     llvm::ArrayRef<unsigned_t> strides, llvm::ArrayRef<unsigned_t> pads,
2895     unsigned_t group, unsigned_t dilation, bool quantizeFilter,
2896     bool quantizeBias, quantization::Schema schema, ElemKind filterElemQTy,
2897     ElemKind biasElemQTy) {
2898 
2899   // Validate dimensions.
2900   bool isConv3D = (input.getType()->dims().size() == 5);
2901   if (isConv3D) {
2902     assertConv3DDims(input, filter, bias, kernels, strides, pads, group);
2903   } else {
2904     assertConvDims(input, filter, bias, kernels, strides, pads, group);
2905   }
2906 
2907   // Validate bias precision.
2908   auto biasElemKind = bias.getElementType();
2909   DCHECK(biasElemKind == ElemKind::Int8QTy ||
2910          biasElemKind == ElemKind::Int32QTy ||
2911          biasElemKind == ElemKind::FloatTy)
2912       << "Unsupported element type for ChannelwiseQuantizedConvolution bias: "
2913       << Type::getElementName(biasElemKind).str();
2914 
2915   // Validate filter precision.
2916   auto filterElemKind = filter.getElementType();
2917   DCHECK(filterElemKind == ElemKind::Int8QTy ||
2918          filterElemKind == ElemKind::FloatTy)
2919       << "Unsupported element type for ChannelwiseQuantizedConvolution "
2920       << "filter: " << Type::getElementName(filterElemKind).str();
2921 
2922   DCHECK(dyn_cast<Constant>(bias.getNode()))
2923       << "Bias input to ChannelwiseQuantizedConvolutionNode must be a Constant";
2924 
2925   DCHECK(dyn_cast<Constant>(filter.getNode()))
2926       << "Filter input to ChannelwiseQuantizedConvolutionNode must be a "
2927          "Constant";
2928 
2929   DCHECK(!filterScales.getNode() || dyn_cast<Constant>(filterScales.getNode()))
2930       << "Filter scales input to ChannelwiseQuantizedConvolutionNode must be "
2931          "null or Constant";
2932 
2933   DCHECK(!filterOffsets.getNode() ||
2934          dyn_cast<Constant>(filterOffsets.getNode()))
2935       << "Filter offsets input to ChannelwiseQuantizedConvolutionNode must be "
2936          "null or Constant";
2937 
2938   DCHECK(!biasScales.getNode() || dyn_cast<Constant>(biasScales.getNode()))
2939       << "Bias scales input to ChannelwiseQuantizedConvolutionNode must be "
2940          "null or Constant";
2941 
2942   DCHECK(!biasOffsets.getNode() || dyn_cast<Constant>(biasOffsets.getNode()))
2943       << "Bias offsets input to ChannelwiseQuantizedConvolutionNode must be "
2944          "null or Constant";
2945 
2946   // Number of output channels.
2947   dim_t numChannels = outTy->dims().back();
2948   dim_t qDim = 0;
2949   dim_t qStep = 1;
2950 
2951   // If input filter is FLOAT and filterScales/filterOffsets are NOT provided
2952   // then compute them automatically for given schema and filterElemQTy.
2953   // If input filter is QUANTIZED then filterScales/filterOffsets are mandatory.
2954   if (!filterScales.getNode() || !filterOffsets.getNode()) {
2955     DCHECK(filterElemKind == ElemKind::FloatTy)
2956         << "ChannelwiseQuantizedConvolution: If the input filter is "
2957         << "quantized then the filter scales/offsets must be provided!";
2958     Constant *filterC = dyn_cast<Constant>(filter.getNode());
2959     Constant *filterScalesC = getParent()->createConstant(
2960         ElemKind::FloatTy, {numChannels}, "filterScales");
2961     Constant *filterOffsetsC = getParent()->createConstant(
2962         ElemKind::Int32ITy, {numChannels}, "filterOffsets");
2963     // Get filter channelwise TensorQuantizationParams.
2964     quantization::getTensorQuantizationParams(
2965         filterC->getPayload(), filterScalesC->getPayloadMutable(),
2966         filterOffsetsC->getPayloadMutable(), schema, filterElemQTy, qDim,
2967         qStep);
2968     filterScales = NodeValue(filterScalesC);
2969     filterOffsets = NodeValue(filterOffsetsC);
2970   }
2971 
2972   // If input filter is FLOAT then quantize channel wise to filterElemQTy.
2973   if (quantizeFilter && filterElemKind == ElemKind::FloatTy) {
2974     Constant *filterC = dyn_cast<Constant>(filter.getNode());
2975     Constant *filterCQ = getParent()->createConstant(
2976         filterElemQTy, filterC->getType()->dims(), 1.0, 0, "filter");
2977     Constant *filterScalesC = dyn_cast<Constant>(filterScales.getNode());
2978     Constant *filterOffsetsC = dyn_cast<Constant>(filterOffsets.getNode());
2979     // Quantize filter channelwise.
2980     filterCQ->getPayloadMutable() = quantization::quantizeTensor(
2981         filterC->getPayload(), filterScalesC->getPayload(),
2982         filterOffsetsC->getPayload(), filterElemQTy, qDim, qStep);
2983     filter = NodeValue(filterCQ);
2984   }
2985 
2986   // If input bias is FLOAT and biasScales/biasOffsets are NOT provided
2987   // then compute them automatically for given schema and biasElemQTy.
2988   // If input bias is QUANTIZED and biasScales/biasOffsets are NOT provided
2989   // then assume the channel wise quantization parameters are implicitly:
2990   // biasScales[i] = inputScale * filterScales[i] and biasOffsets[i] = 0.
2991   if (!biasScales.getNode() || !biasOffsets.getNode()) {
2992     Constant *biasC = dyn_cast<Constant>(bias.getNode());
2993     Constant *biasScalesC = getParent()->createConstant(
2994         ElemKind::FloatTy, {numChannels}, "biasScales");
2995     Constant *biasOffsetsC = getParent()->createConstant(
2996         ElemKind::Int32ITy, {numChannels}, "biasOffsets");
2997     auto biasScalesH = biasScalesC->getPayload().getHandle<float>();
2998     auto biasOffsetsH = biasOffsetsC->getPayload().getHandle<int32_t>();
2999     Constant *filterScalesC = dyn_cast<Constant>(filterScales.getNode());
3000     Constant *filterOffsetsC = dyn_cast<Constant>(filterOffsets.getNode());
3001     auto filterScalesH = filterScalesC->getPayload().getHandle<float>();
3002     auto filterOffsetsH = filterOffsetsC->getPayload().getHandle<int32_t>();
3003     auto inputScale = input.getType()->getScale();
3004     auto inputOffset = input.getType()->getOffset();
3005     if (biasElemKind == ElemKind::FloatTy) {
3006       // Get bias channelwise TensorQuantizationParams.
3007       quantization::getTensorQuantizationParams(
3008           biasC->getPayload(), biasScalesC->getPayloadMutable(),
3009           biasOffsetsC->getPayloadMutable(), schema, biasElemQTy, qDim, qStep);
3010       // Specialize the bias channelwise TensorQuantizationParams.
3011       for (dim_t idx = 0; idx < numChannels; idx++) {
3012         auto biasTQPNew = specializeBiasQuantizationParams(
3013             {biasScalesH.raw(idx), biasOffsetsH.raw(idx)},
3014             {inputScale, inputOffset},
3015             {filterScalesH.raw(idx), filterOffsetsH.raw(idx)}, schema,
3016             biasElemQTy);
3017         biasScalesH.raw(idx) = biasTQPNew.scale;
3018         biasOffsetsH.raw(idx) = biasTQPNew.offset;
3019       }
3020     } else {
3021       // Set implicit bias channelwise TensorQuantizationParams.
3022       for (dim_t idx = 0; idx < numChannels; idx++) {
3023         float filterScale = filterScalesH.raw(idx);
3024         biasScalesH.raw(idx) = inputScale * filterScale;
3025         biasOffsetsH.raw(idx) = 0;
3026       }
3027     }
3028     biasScales = NodeValue(biasScalesC);
3029     biasOffsets = NodeValue(biasOffsetsC);
3030   }
3031 
3032   // If input bias is FLOAT then quantize channel wise to biasElemQTy.
3033   if (quantizeBias && biasElemKind == ElemKind::FloatTy) {
3034     Constant *biasC = dyn_cast<Constant>(bias.getNode());
3035     Constant *biasCQ = getParent()->createConstant(
3036         biasElemQTy, biasC->getType()->dims(), 1.0, 0, "bias");
3037     Constant *biasScalesC = dyn_cast<Constant>(biasScales.getNode());
3038     Constant *biasOffsetsC = dyn_cast<Constant>(biasOffsets.getNode());
3039     // Quantize bias channelwise.
3040     biasCQ->getPayloadMutable() = quantization::quantizeTensor(
3041         biasC->getPayload(), biasScalesC->getPayload(),
3042         biasOffsetsC->getPayload(), biasElemQTy, qDim, qStep);
3043     bias = NodeValue(biasCQ);
3044   }
3045 
3046   auto OT = getParent()->uniqueType(*outTy);
3047   return addNode(new ChannelwiseQuantizedConvolutionNode(
3048       name, OT, input, filter, bias, filterScales, filterOffsets, biasScales,
3049       biasOffsets, kernels, strides, pads, group, dilation));
3050 }
3051 
3052 ConvTransposeNode *Function::createConvTranspose(
3053     PlaceholderBindings &bindings, llvm::StringRef name, NodeValue input,
3054     dim_t outChannels, llvm::ArrayRef<unsigned_t> kernels,
3055     llvm::ArrayRef<unsigned_t> strides, llvm::ArrayRef<unsigned_t> pads,
3056     unsigned_t group, unsigned_t dilation) {
3057   ShapeNHWC idim = ShapeNHWC(input.dims());
3058   ShapeHW kdim(kernels);
3059   PaddingTLBR pdim(pads);
3060   (void)pdim;
3061   assert((idim.w + pdim.left + pdim.right) >= kdim.width &&
3062          (idim.h + pdim.top + pdim.bottom) >= kdim.height &&
3063          "buffer too small for selected stride");
3064 
3065   assert(group > 0 && "group should be larger than 0");
3066   assert(idim.c % group == 0 && "channels number must be divisible by groups");
3067   assert(outChannels % group == 0 && "outChannels must be divisible by groups");
3068 
3069   // Calculate the size and allocate the output buffer.
3070   auto outSz = calculateConvTransposeOutputDims(idim.h, idim.w, kernels,
3071                                                 strides, pads, dilation);
3072 
3073   std::array<dim_t, 4> outDims = {
3074       {idim.n, outSz.first, outSz.second, outChannels}};
3075 
3076   // Allocate the Filter and Bias tensors.
3077   std::array<dim_t, 4> filterDim = {
3078       {outChannels, kdim.height, kdim.width, idim.c / group}};
3079   size_t fanIn = kdim.height * kdim.width * idim.c;
3080   ElemKind inputTy = input.getType()->getElementType();
3081   assert((inputTy == ElemKind::FloatTy || inputTy == ElemKind::Float16Ty) &&
3082          "ConvTranspose on non-floating point type?");
3083   auto *filter =
3084       getParent()->createPlaceholder(inputTy, filterDim, "filter", true);
3085 
3086   auto *bias =
3087       getParent()->createPlaceholder(inputTy, {outChannels}, "bias", true);
3088   bindings.allocate(bias)->init(glow::Tensor::InitKind::Broadcast, 0.1,
3089                                 getPRNG());
3090 
3091   bindings.allocate(filter)->init(glow::Tensor::InitKind::Xavier, fanIn,
3092                                   getPRNG());
3093 
3094   auto OT = getParent()->uniqueType(inputTy, outDims);
3095 
3096   return addNode(new ConvTransposeNode(name, OT, input, filter, bias, kernels,
3097                                        strides, pads, group, dilation));
3098 }
3099 
3100 ConvTransposeNode *Function::createConvTranspose(
3101     PlaceholderBindings &bindings, llvm::StringRef name, NodeValue input,
3102     dim_t outChannels, unsigned_t kernel, unsigned_t stride, unsigned_t pad,
3103     unsigned_t group, unsigned_t dilation) {
3104   llvm::SmallVector<unsigned_t, 4> pads = {pad, pad, pad, pad};
3105   llvm::SmallVector<unsigned_t, 2> strides = {stride, stride};
3106   llvm::SmallVector<unsigned_t, 2> kernels = {kernel, kernel};
3107   return createConvTranspose(bindings, name, input, outChannels, kernels,
3108                              strides, pads, group, dilation);
3109 }
3110 
3111 ConvertToNode *Function::createConvertTo(llvm::StringRef name, NodeValue input,
3112                                          TypeRef outTy) {
3113   return addNode(new ConvertToNode(name, outTy, input));
3114 }
3115 
3116 ConvertToNode *Function::createConvertTo(llvm::StringRef name, NodeValue input,
3117                                          ElemKind k) {
3118   auto OT = getParent()->uniqueType(k, input.dims());
3119   return addNode(new ConvertToNode(name, OT, input));
3120 }
3121 
3122 FullyConnectedNode *
3123 Function::createFullyConnected(PlaceholderBindings &bindings,
3124                                llvm::StringRef name, NodeValue input,
3125                                dim_t outDepth, unsigned_t axis) {
3126   const ElemKind k = input.getType()->getElementType();
3127 
3128   // FC always uses 2D input; flatten if necessary.
3129   if (input.dims().size() != 2) {
3130     input = createFlatten(name.str() + ".reshape2D", input, axis);
3131   }
3132   auto *W = getParent()->createPlaceholder(k, {input.dims()[1], outDepth},
3133                                            "weights", true);
3134   auto *B = getParent()->createPlaceholder(k, {outDepth}, "bias", true);
3135 
3136   bindings.allocate(W)->init(Tensor::InitKind::Xavier, input.dims()[1],
3137                              getPRNG());
3138   bindings.allocate(B)->init(Tensor::InitKind::Broadcast, .1, getPRNG());
3139 
3140   auto OT = getParent()->uniqueType(k, {input.dims()[0], outDepth});
3141   return createFullyConnected(name, input, W, B, OT, axis);
3142 }
3143 
3144 Node *Function::createDotProduct(llvm::StringRef name, NodeValue X,
3145                                  NodeValue Y) {
3146   auto XDimsSize = X.dims().size();
3147   (void)XDimsSize;
3148 
3149   assert(X.dims() == Y.dims() && "X and Y must have the same shape");
3150   assert(((XDimsSize == 1) || (XDimsSize == 2)) && "X and Y must be 1D or 2D");
3151 
3152   // Create Mul node.
3153   auto *MN = createMul(name.str() + ".mul", X, Y);
3154 
3155   // If X and Y are 1D, the BatchedReduceAdd node is not needed.
3156   if (XDimsSize == 1) {
3157     return MN;
3158   }
3159 
3160   // Create and return BatchedReduceAdd node.
3161   return createBatchedReduceAdd(name.str() + ".bra", MN, 1);
3162 }
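// Illustrative shape sketch for createDotProduct (hypothetical sizes): for X
// and Y of shape {8, 32} this produces a node of shape {8} holding the
// per-row dot products; for 1D inputs of shape {32} the elementwise Mul of
// shape {32} is returned directly, without the reduction.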
3163 
3164 BatchedPairwiseDotProductNode *
3165 Function::createBatchedPairwiseDotProduct(llvm::StringRef name,
3166                                           llvm::ArrayRef<NodeValue> inputs) {
3167   assert(!inputs.empty());
3168   dim_t batchCount = inputs[0].getType()->dims()[0];
3169   dim_t numPairs = inputs.size() * (inputs.size() - 1) / 2;
3170   auto *outTy = getParent()->uniqueTypeWithNewShape(inputs[0].getType(),
3171                                                     {batchCount, numPairs});
3172 
3173   return addNode(new BatchedPairwiseDotProductNode(name, outTy, inputs));
3174 }
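// Illustrative sketch for createBatchedPairwiseDotProduct (hypothetical
// sizes): with 4 inputs, each of shape {batchCount, D}, numPairs is
// 4 * 3 / 2 == 6 and the output has shape {batchCount, 6}, one dot product
// per unordered pair of inputs.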
3175 
3176 Node *Function::createElementwiseLinear(llvm::StringRef name, NodeValue X,
3177                                         NodeValue w, NodeValue b,
3178                                         unsigned axis) {
3179   auto XDims = X.dims();
3180   auto wDims = w.dims();
3181   auto bDims = b.dims();
3182 
3183   // Suppress release mode unused variable warnings.
3184   (void)wDims;
3185   (void)bDims;
3186 
3187   // Check that the inputs are sensible.
3188   assert(XDims.size() == 2 && "X must be 2D");
3189   assert((axis == 0 || axis == 1) && "axis must be 0 or 1");
3190   assert(wDims.size() == 1 && "w must be 1D");
3191   assert(bDims.size() == 1 && "b must be 1D");
3192   assert(wDims[0] == XDims[axis] &&
3193          "size of w must match input dimension of X");
3194   assert(bDims[0] == XDims[axis] &&
3195          "size of b must match input dimension of X");
3196 
3197   // Broadcast w and b so that they have the same dimensions as X.
3198   auto *broadcastW =
3199       createBroadcast(name.str() + ".broadcastW", w, XDims, axis);
3200   auto *broadcastB =
3201       createBroadcast(name.str() + ".broadcastB", b, XDims, axis);
3202 
3203   // Implement the elementwise linear operation by multiplying X elementwise
3204   // with broadcasted w and adding broadcasted b elementwise.
3205   auto *wX = createMul(name.str() + ".mul", broadcastW, X);
3206   auto *out = createAdd(name.str() + ".add", wX, broadcastB);
3207 
3208   return out;
3209 }
3210 
3211 void Function::createGRU(PlaceholderBindings &bindings,
3212                          llvm::StringRef namePrefix,
3213                          llvm::ArrayRef<NodeValue> inputs, unsigned batchSize,
3214                          unsigned hiddenSize, unsigned outputSize,
3215                          std::vector<NodeValue> &outputs) {
3216   std::string nameBase = namePrefix;
3217   const unsigned timeSteps = inputs.size();
3218   assert(timeSteps > 0 && "empty input");
3219   const unsigned inputSize = inputs.front().dims().back();
3220   assert(inputSize > 0 && "input dimensionality is zero");
3221 
3222   // Initialize the state to zero.
3223   Placeholder *HInit = getParent()->createPlaceholder(
3224       ElemKind::FloatTy, {batchSize, hiddenSize}, "initial_state", false);
3225   bindings.allocate(HInit)->zero();
3226   Node *Ht = HInit;
3227 
3228   // Update gate:
3229   //    Z <- sigmoid(Wxz * x + Whz * h + bz)
3230   // Reset gate:
3231   //    R <- sigmoid(Wxr * x + Whr * h + br)
3232   // Hidden state:
3233   //    h <- Z . h + (1 - Z) tanh (Wxh * x + Whh * (R . h) + bh)
3234 
3235   // update gate
3236   float bUpdate = 0.1;
3237   Placeholder *Wxz = getParent()->createPlaceholder(
3238       ElemKind::FloatTy, {inputSize, hiddenSize}, nameBase + ".Wxz", true);
3239   Placeholder *Whz = getParent()->createPlaceholder(
3240       ElemKind::FloatTy, {hiddenSize, hiddenSize}, nameBase + ".Whz", true);
3241   Placeholder *Bz1 = getParent()->createPlaceholder(
3242       ElemKind::FloatTy, {hiddenSize}, nameBase + ".bz1", true);
3243   Placeholder *Bz2 = getParent()->createPlaceholder(
3244       ElemKind::FloatTy, {hiddenSize}, nameBase + ".bz2", true);
3245 
3246   bindings.allocate(Wxz)->init(glow::Tensor::InitKind::Xavier, inputSize,
3247                                getPRNG());
3248   bindings.allocate(Whz)->init(glow::Tensor::InitKind::Xavier, hiddenSize,
3249                                getPRNG());
3250   bindings.allocate(Bz1)->init(glow::Tensor::InitKind::Broadcast, bUpdate,
3251                                getPRNG());
3252   bindings.allocate(Bz2)->init(glow::Tensor::InitKind::Broadcast, bUpdate,
3253                                getPRNG());
3254 
3255   // Reset gate.
3256   float bReset = -1.0;
3257   Placeholder *Wxr = getParent()->createPlaceholder(
3258       ElemKind::FloatTy, {inputSize, hiddenSize}, nameBase + ".Wxr", true);
3259   Placeholder *Whr = getParent()->createPlaceholder(
3260       ElemKind::FloatTy, {hiddenSize, hiddenSize}, nameBase + ".Whr", true);
3261   Placeholder *Br1 = getParent()->createPlaceholder(
3262       ElemKind::FloatTy, {hiddenSize}, nameBase + ".br1", true);
3263   Placeholder *Br2 = getParent()->createPlaceholder(
3264       ElemKind::FloatTy, {hiddenSize}, nameBase + ".br2", true);
3265 
3266   bindings.allocate(Wxr)->init(glow::Tensor::InitKind::Xavier, inputSize,
3267                                getPRNG());
3268   bindings.allocate(Whr)->init(glow::Tensor::InitKind::Xavier, hiddenSize,
3269                                getPRNG());
3270   bindings.allocate(Br1)->init(glow::Tensor::InitKind::Broadcast, bReset,
3271                                getPRNG());
3272   bindings.allocate(Br2)->init(glow::Tensor::InitKind::Broadcast, bReset,
3273                                getPRNG());
3274 
3275   // hidden state
3276   float b = 0.1;
3277   Placeholder *Wxh = getParent()->createPlaceholder(
3278       ElemKind::FloatTy, {inputSize, hiddenSize}, nameBase + ".Wxh", true);
3279   Placeholder *Whh = getParent()->createPlaceholder(
3280       ElemKind::FloatTy, {hiddenSize, hiddenSize}, nameBase + ".Whh", true);
3281   Placeholder *Bh1 = getParent()->createPlaceholder(
3282       ElemKind::FloatTy, {hiddenSize}, nameBase + ".bh1", true);
3283   Placeholder *Bh2 = getParent()->createPlaceholder(
3284       ElemKind::FloatTy, {hiddenSize}, nameBase + ".bh2", true);
3285 
3286   bindings.allocate(Wxh)->init(glow::Tensor::InitKind::Xavier, inputSize,
3287                                getPRNG());
3288   bindings.allocate(Whh)->init(glow::Tensor::InitKind::Xavier, hiddenSize,
3289                                getPRNG());
3290   bindings.allocate(Bh1)->init(glow::Tensor::InitKind::Broadcast, b, getPRNG());
3291   bindings.allocate(Bh2)->init(glow::Tensor::InitKind::Broadcast, b, getPRNG());
3292 
3293   // Output Layer.
3294   Placeholder *Why = getParent()->createPlaceholder(
3295       ElemKind::FloatTy, {hiddenSize, outputSize}, nameBase + ".Why", true);
3296   Placeholder *By = getParent()->createPlaceholder(
3297       ElemKind::FloatTy, {outputSize}, nameBase + ".by", true);
3298 
3299   bindings.allocate(Why)->init(glow::Tensor::InitKind::Xavier, hiddenSize,
3300                                getPRNG());
3301   bindings.allocate(By)->init(glow::Tensor::InitKind::Broadcast, b, getPRNG());
3302 
3303   auto ty = getParent()->uniqueType(ElemKind::FloatTy, {batchSize, hiddenSize});
3304   auto *Ones = createSplat(nameBase + ".ones", ty, 1.0);
3305 
3306   std::vector<Node *> outputNodes;
3307   for (unsigned t = 0; t < timeSteps; t++) {
3308     auto fc1Name = nameBase + ".fc1." + std::to_string(t);
3309     auto fc2Name = nameBase + ".fc2." + std::to_string(t);
3310     auto add1Name = nameBase + ".add1." + std::to_string(t);
3311     auto sigmoid1Name = nameBase + ".sigmoid1." + std::to_string(t);
3312 
3313     auto *Zt = createSigmoid(
3314         sigmoid1Name,
3315         createAdd(add1Name, createFullyConnected(fc1Name, Ht, Whz, Bz1),
3316                   createFullyConnected(fc2Name, inputs[t], Wxz, Bz2)));
3317 
3318     auto fc3Name = nameBase + ".fc3." + std::to_string(t);
3319     auto fc4Name = nameBase + ".fc4." + std::to_string(t);
3320     auto add2Name = nameBase + ".add2." + std::to_string(t);
3321     auto sigmoid2Name = nameBase + ".sigmoid2." + std::to_string(t);
3322 
3323     auto *Rt = createSigmoid(
3324         sigmoid2Name,
3325         createAdd(add2Name, createFullyConnected(fc3Name, Ht, Whr, Br1),
3326                   createFullyConnected(fc4Name, inputs[t], Wxr, Br2)));
3327 
3328     auto zhtName = nameBase + ".zh." + std::to_string(t);
3329     auto *ZHt = createMul(zhtName, Zt, Ht);
3330 
3331     auto oneMinusZtName = nameBase + ".1-z." + std::to_string(t);
3332     auto *OneMinusZt = createSub(oneMinusZtName, Ones, Zt);
3333 
3334     auto rhtName = nameBase + ".rh." + std::to_string(t);
3335     auto *RHt = createMul(rhtName, Rt, Ht);
3336 
3337     auto fc5Name = nameBase + ".fc5." + std::to_string(t);
3338     auto fc6Name = nameBase + ".fc6." + std::to_string(t);
3339     auto add3Name = nameBase + ".add3." + std::to_string(t);
3340     auto tanh1Name = nameBase + ".tanh1." + std::to_string(t);
3341 
3342     auto *Ut = createTanh(
3343         tanh1Name,
3344         createAdd(add3Name, createFullyConnected(fc5Name, RHt, Whh, Bh1),
3345                   createFullyConnected(fc6Name, inputs[t], Wxh, Bh2)));
3346 
3347     auto oneMinusZtUtName = nameBase + ".1-zu." + std::to_string(t);
3348     auto *OneMinusZtUt = createMul(oneMinusZtUtName, OneMinusZt, Ut);
3349 
3350     auto htName = nameBase + ".H." + std::to_string(t);
3351     Ht = createAdd(htName, ZHt, OneMinusZtUt);
3352 
3353     auto outName = nameBase + ".out." + std::to_string(t);
3354     auto *O = createFullyConnected(outName, Ht, Why, By);
3355     outputs.push_back(O);
3356   }
3357 }
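
// Minimal usage sketch for the unrolled GRU above (hedged: "F" is a Function,
// "bindings" a PlaceholderBindings, and "inputs" holds one {batchSize,
// inputSize} NodeValue per time step; the sizes below are illustrative).
//
//   std::vector<NodeValue> gruOuts;
//   F->createGRU(bindings, "gru", inputs, /* batchSize */ 4,
//                /* hiddenSize */ 32, /* outputSize */ 10, gruOuts);
//   // gruOuts[t] is the {4, 10} fully connected projection (Why, by) of the
//   // hidden state at step t.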
3358 
3359 void Function::createSimpleRNN(PlaceholderBindings &bindings,
3360                                llvm::StringRef namePrefix,
3361                                llvm::ArrayRef<NodeValue> inputs,
3362                                unsigned batchSize, unsigned hiddenSize,
3363                                unsigned outputSize,
3364                                std::vector<NodeValue> &outputs) {
3365   std::string nameBase = namePrefix;
3366   const unsigned timeSteps = inputs.size();
3367   assert(timeSteps > 0 && "empty input");
3368   const unsigned inputSize = inputs.front().dims().back();
3369   assert(inputSize > 0 && "input dimensionality is zero");
3370 
3371   // Initialize the state to zero.
3372   Placeholder *HInit =
3373       getParent()->createPlaceholder(ElemKind::FloatTy, {batchSize, hiddenSize},
3374                                      nameBase + ".initial_state", false);
3375   bindings.allocate(HInit)->zero();
3376   Node *Ht = HInit;
3377 
3378   float b = 0.1;
3379   Placeholder *Whh = getParent()->createPlaceholder(
3380       ElemKind::FloatTy, {hiddenSize, hiddenSize}, nameBase + ".Whh", true);
3381   Placeholder *Bhh = getParent()->createPlaceholder(
3382       ElemKind::FloatTy, {hiddenSize}, nameBase + ".Bhh", true);
3383   Placeholder *Wxh = getParent()->createPlaceholder(
3384       ElemKind::FloatTy, {inputSize, hiddenSize}, nameBase + ".Wxh", true);
3385 
3386   Placeholder *Bxh = getParent()->createPlaceholder(
3387       ElemKind::FloatTy, {hiddenSize}, nameBase + ".Bxh", true);
3388   Placeholder *Why = getParent()->createPlaceholder(
3389       ElemKind::FloatTy, {hiddenSize, outputSize}, nameBase + ".Why", true);
3390   Placeholder *Bhy = getParent()->createPlaceholder(
3391       ElemKind::FloatTy, {outputSize}, nameBase + ".Bhy", true);
3392 
3393   bindings.allocate(Whh)->init(glow::Tensor::InitKind::Xavier, hiddenSize,
3394                                getPRNG());
3395   bindings.allocate(Bhh)->init(glow::Tensor::InitKind::Broadcast, b, getPRNG());
3396   bindings.allocate(Wxh)->init(glow::Tensor::InitKind::Xavier, inputSize,
3397                                getPRNG());
3398   bindings.allocate(Bxh)->init(glow::Tensor::InitKind::Broadcast, b, getPRNG());
3399   bindings.allocate(Why)->init(glow::Tensor::InitKind::Xavier, hiddenSize,
3400                                getPRNG());
3401   bindings.allocate(Bhy)->init(glow::Tensor::InitKind::Broadcast, b, getPRNG());
3402 
3403   // Unroll backpropagation through time as a loop with the shared
3404   // parameters.
3405   for (unsigned t = 0; t < timeSteps; t++) {
3406     auto fc1Name = nameBase + ".fc1." + std::to_string(t);
3407     auto *FC1 = createFullyConnected(fc1Name, Ht, Whh, Bhh);
3408     auto fc2Name = nameBase + ".fc2." + std::to_string(t);
3409     auto *FC2 = createFullyConnected(fc2Name, inputs[t], Wxh, Bxh);
3410     auto aName = nameBase + ".add." + std::to_string(t);
3411     auto *A = createAdd(aName, FC1, FC2);
3412     auto tanhName = nameBase + ".tanh." + std::to_string(t);
3413     auto *H = createTanh(tanhName, A);
3414     auto outName = nameBase + ".out." + std::to_string(t);
3415     auto *O = createFullyConnected(outName, H, Why, Bhy);
3416     outputs.push_back(O);
3417 
3418     Ht = H;
3419   }
3420 }
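
// createSimpleRNN follows the same calling convention; a hedged sketch with
// the same assumed "F", "bindings" and "inputs" as the GRU example above:
//
//   std::vector<NodeValue> rnnOuts;
//   F->createSimpleRNN(bindings, "rnn", inputs, /* batchSize */ 4,
//                      /* hiddenSize */ 32, /* outputSize */ 10, rnnOuts);
//   // Per step: h_t = tanh(h_{t-1} * Whh + Bhh + x_t * Wxh + Bxh) and
//   // rnnOuts[t] = h_t * Why + Bhy.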
3421 
3422 void Function::createLSTM(PlaceholderBindings &bindings,
3423                           llvm::StringRef namePrefix,
3424                           llvm::ArrayRef<NodeValue> inputs, unsigned batchSize,
3425                           unsigned hiddenSize, unsigned outputSize,
3426                           std::vector<NodeValue> &outputs) {
3427   std::string nameBase = namePrefix;
3428   const unsigned timeSteps = inputs.size();
3429   assert(timeSteps > 0 && "empty input");
3430   const unsigned inputSize = inputs.front().dims().back();
3431   assert(inputSize > 0 && "input dimensionality is zero");
3432 
3433   // Initialize the hidden and cell states to zero.
3434   Placeholder *HInit =
3435       getParent()->createPlaceholder(ElemKind::FloatTy, {batchSize, hiddenSize},
3436                                      "initial_hidden_state", false);
3437   bindings.allocate(HInit)->zero();
3438   Node *Ht = HInit;
3439 
3440   Placeholder *CInit = getParent()->createPlaceholder(
3441       ElemKind::FloatTy, {batchSize, hiddenSize}, "initial_cell_state", false);
3442   bindings.allocate(CInit)->zero();
3443   Node *Ct = CInit;
3444 
3445   // Forget gate:
3446   //    F <- sigmoid(Wxf * x + Whf * h + bf)
3447   // Input gate:
3448   //    I <- sigmoid(Wxi * x + Whi * h + bi)
3449   // Output gate:
3450   //    O <- sigmoid(Wxo * x + Who * h + bo)
3451   // Cell state:
3452   //    C <- F . C + I . tanh(Wxc * x + Whc * h + bc)
3453   // Hidden state:
3454   //    h <- O . tanh(C)
3455 
3456   // forget gate
3457   float bForget = 1.0;
3458   Placeholder *Wxf = getParent()->createPlaceholder(
3459       ElemKind::FloatTy, {inputSize, hiddenSize}, nameBase + ".Wxf", true);
3460   Placeholder *Whf = getParent()->createPlaceholder(
3461       ElemKind::FloatTy, {hiddenSize, hiddenSize}, nameBase + ".Whf", true);
3462   Placeholder *Bf1 = getParent()->createPlaceholder(
3463       ElemKind::FloatTy, {hiddenSize}, nameBase + ".bf1", true);
3464   Placeholder *Bf2 = getParent()->createPlaceholder(
3465       ElemKind::FloatTy, {hiddenSize}, nameBase + ".bf2", true);
3466   bindings.allocate(Wxf)->init(glow::Tensor::InitKind::Xavier, inputSize,
3467                                getPRNG());
3468   bindings.allocate(Whf)->init(glow::Tensor::InitKind::Xavier, hiddenSize,
3469                                getPRNG());
3470   bindings.allocate(Bf1)->init(glow::Tensor::InitKind::Broadcast, bForget,
3471                                getPRNG());
3472   bindings.allocate(Bf2)->init(glow::Tensor::InitKind::Broadcast, bForget,
3473                                getPRNG());
3474 
3475   // input gate
3476   float bInput = 0.1;
3477   Placeholder *Wxi = getParent()->createPlaceholder(
3478       ElemKind::FloatTy, {inputSize, hiddenSize}, nameBase + ".Wxi", true);
3479   Placeholder *Whi = getParent()->createPlaceholder(
3480       ElemKind::FloatTy, {hiddenSize, hiddenSize}, nameBase + ".Whi", true);
3481   Placeholder *Bi1 = getParent()->createPlaceholder(
3482       ElemKind::FloatTy, {hiddenSize}, nameBase + ".bi1", true);
3483   Placeholder *Bi2 = getParent()->createPlaceholder(
3484       ElemKind::FloatTy, {hiddenSize}, nameBase + ".bi2", true);
3485 
3486   bindings.allocate(Wxi)->init(glow::Tensor::InitKind::Xavier, inputSize,
3487                                getPRNG());
3488   bindings.allocate(Whi)->init(glow::Tensor::InitKind::Xavier, hiddenSize,
3489                                getPRNG());
3490   bindings.allocate(Bi1)->init(glow::Tensor::InitKind::Broadcast, bInput,
3491                                getPRNG());
3492   bindings.allocate(Bi2)->init(glow::Tensor::InitKind::Broadcast, bInput,
3493                                getPRNG());
3494 
3495   // output gate
3496   float bOutput = 0.1;
3497   Placeholder *Wxo = getParent()->createPlaceholder(
3498       ElemKind::FloatTy, {inputSize, hiddenSize}, nameBase + ".Wxo", true);
3499   Placeholder *Who = getParent()->createPlaceholder(
3500       ElemKind::FloatTy, {hiddenSize, hiddenSize}, nameBase + ".Who", true);
3501   Placeholder *Bo1 = getParent()->createPlaceholder(
3502       ElemKind::FloatTy, {hiddenSize}, nameBase + ".bo1", true);
3503   Placeholder *Bo2 = getParent()->createPlaceholder(
3504       ElemKind::FloatTy, {hiddenSize}, nameBase + ".bo2", true);
3505 
3506   bindings.allocate(Wxo)->init(glow::Tensor::InitKind::Xavier, inputSize,
3507                                getPRNG());
3508   bindings.allocate(Who)->init(glow::Tensor::InitKind::Xavier, hiddenSize,
3509                                getPRNG());
3510   bindings.allocate(Bo1)->init(glow::Tensor::InitKind::Broadcast, bOutput,
3511                                getPRNG());
3512   bindings.allocate(Bo2)->init(glow::Tensor::InitKind::Broadcast, bOutput,
3513                                getPRNG());
3514 
3515   // cell state
3516   float bCell = 0.1;
3517   Placeholder *Wxc = getParent()->createPlaceholder(
3518       ElemKind::FloatTy, {inputSize, hiddenSize}, nameBase + ".Wxc", true);
3519   Placeholder *Whc = getParent()->createPlaceholder(
3520       ElemKind::FloatTy, {hiddenSize, hiddenSize}, nameBase + ".Whc", true);
3521   Placeholder *Bc1 = getParent()->createPlaceholder(
3522       ElemKind::FloatTy, {hiddenSize}, nameBase + ".bc1", true);
3523   Placeholder *Bc2 = getParent()->createPlaceholder(
3524       ElemKind::FloatTy, {hiddenSize}, nameBase + ".bc2", true);
3525 
3526   bindings.allocate(Wxc)->init(glow::Tensor::InitKind::Xavier, inputSize,
3527                                getPRNG());
3528   bindings.allocate(Whc)->init(glow::Tensor::InitKind::Xavier, hiddenSize,
3529                                getPRNG());
3530   bindings.allocate(Bc1)->init(glow::Tensor::InitKind::Broadcast, bCell,
3531                                getPRNG());
3532   bindings.allocate(Bc2)->init(glow::Tensor::InitKind::Broadcast, bCell,
3533                                getPRNG());
3534 
3535   // output layer
3536   float b = 0.1;
3537   Placeholder *Why = getParent()->createPlaceholder(
3538       ElemKind::FloatTy, {hiddenSize, outputSize}, nameBase + ".Why", true);
3539   Placeholder *By = getParent()->createPlaceholder(
3540       ElemKind::FloatTy, {outputSize}, nameBase + ".by", true);
3541 
3542   bindings.allocate(Why)->init(glow::Tensor::InitKind::Xavier, hiddenSize,
3543                                getPRNG());
3544   bindings.allocate(By)->init(glow::Tensor::InitKind::Broadcast, b, getPRNG());
3545 
3546   std::vector<Node *> outputNodes;
3547   for (unsigned t = 0; t < timeSteps; t++) {
3548     auto fc1Name = nameBase + ".fc1." + std::to_string(t);
3549     auto fc2Name = nameBase + ".fc2." + std::to_string(t);
3550     auto add1Name = nameBase + ".add1." + std::to_string(t);
3551     auto sigmoid1Name = nameBase + ".sigmoid1." + std::to_string(t);
3552 
3553     auto *Ft = createSigmoid(
3554         sigmoid1Name,
3555         createAdd(add1Name, createFullyConnected(fc1Name, Ht, Whf, Bf1),
3556                   createFullyConnected(fc2Name, inputs[t], Wxf, Bf2)));
3557 
3558     auto fc3Name = nameBase + ".fc3." + std::to_string(t);
3559     auto fc4Name = nameBase + ".fc4." + std::to_string(t);
3560     auto add2Name = nameBase + ".add2." + std::to_string(t);
3561     auto sigmoid2Name = nameBase + ".sigmoid2." + std::to_string(t);
3562 
3563     auto *It = createSigmoid(
3564         sigmoid2Name,
3565         createAdd(add2Name, createFullyConnected(fc3Name, Ht, Whi, Bi1),
3566                   createFullyConnected(fc4Name, inputs[t], Wxi, Bi2)));
3567 
3568     auto fc5Name = nameBase + ".fc5." + std::to_string(t);
3569     auto fc6Name = nameBase + ".fc6." + std::to_string(t);
3570     auto add3Name = nameBase + ".add3." + std::to_string(t);
3571     auto sigmoid3Name = nameBase + ".sigmoid3." + std::to_string(t);
3572 
3573     auto *Ot = createSigmoid(
3574         sigmoid3Name,
3575         createAdd(add3Name, createFullyConnected(fc5Name, Ht, Who, Bo1),
3576                   createFullyConnected(fc6Name, inputs[t], Wxo, Bo2)));
3577 
3578     auto fc7Name = nameBase + ".fc7." + std::to_string(t);
3579     auto fc8Name = nameBase + ".fc8." + std::to_string(t);
3580     auto add4Name = nameBase + ".add4." + std::to_string(t);
3581     auto tanh1Name = nameBase + ".tanh1." + std::to_string(t);
3582 
3583     auto *CRt = createTanh(
3584         tanh1Name,
3585         createAdd(add4Name, createFullyConnected(fc7Name, Ht, Whc, Bc1),
3586                   createFullyConnected(fc8Name, inputs[t], Wxc, Bc2)));
3587 
3588     auto mul1Name = nameBase + ".mul1." + std::to_string(t);
3589     auto mul2Name = nameBase + ".mul2." + std::to_string(t);
3590     Ct = createAdd(nameBase + ".C." + std::to_string(t),
3591                    createMul(mul1Name, Ft, Ct), createMul(mul2Name, It, CRt));
3592 
3593     auto htName = nameBase + ".H." + std::to_string(t);
3594     auto tanh2Name = nameBase + ".tanh2." + std::to_string(t);
3595     Ht = createMul(htName, Ot, createTanh(tanh2Name, Ct));
3596 
3597     auto outName = nameBase + ".out." + std::to_string(t);
3598     auto *O = createFullyConnected(outName, Ht, Why, By);
3599     outputs.push_back(O);
3600   }
3601 }
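
// Hedged usage sketch for the unrolled LSTM above (same assumed "F",
// "bindings" and per-step "inputs" as the GRU example):
//
//   std::vector<NodeValue> lstmOuts;
//   F->createLSTM(bindings, "lstm", inputs, /* batchSize */ 4,
//                 /* hiddenSize */ 32, /* outputSize */ 10, lstmOuts);
//   // Each lstmOuts[t] is the (Why, by) projection of the hidden state h_t.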
3602 
3603 void Function::createOnnxRNN(llvm::StringRef namePrefix, NodeValue X,
3604                              NodeValue W, NodeValue R, NodeValue B,
3605                              NodeValue initial_h, NodeValue &Y, NodeValue &Y_h,
3606                              unsigned hiddenSize, RnnDirection direction,
3607                              std::vector<RnnActivation> &activations) {
3608 
3609 #define RNN_X_SLICE_RANGE(idx)                                                 \
3610   {idx + 0, 0, 0}, { idx + 1, batchSize, inputSize }
3611 #define RNN_W_SLICE_RANGE(idx0, idx1)                                          \
3612   {idx0, idx1 * hiddenSize, 0}, { idx0 + 1, (idx1 + 1) * hiddenSize, inputSize }
3613 #define RNN_R_SLICE_RANGE(idx0, idx1)                                          \
3614   {idx0, idx1 * hiddenSize, 0}, {                                              \
3615     idx0 + 1, (idx1 + 1) * hiddenSize, hiddenSize                              \
3616   }
3617 #define RNN_B_SLICE_RANGE(idx0, idx1)                                          \
3618   {idx0, idx1 * hiddenSize}, { idx0 + 1, (idx1 + 1) * hiddenSize }
3619 #define RNN_H_SLICE_RANGE(idx)                                                 \
3620   {idx + 0, 0, 0}, { idx + 1, batchSize, hiddenSize }
3621 #define RNN_CREATE_FC(name, LHS, RHS, BIAS)                                    \
3622   BIAS ? (Node *)createFullyConnected(name, LHS, RHS, BIAS)                    \
3623        : (Node *)createMatMul(name, LHS, RHS)
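  // For reference, an illustrative expansion of the slice helpers above (not
  // part of the source): with hiddenSize = 4 and inputSize = 3,
  // RNN_W_SLICE_RANGE(0, 0) yields the start/end coordinates
  // {0, 0, 0}, {1, 4, 3}, i.e. the first direction's {hiddenSize, inputSize}
  // block of 'W', which is then reshaped and transposed below.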
3624 
3625   // Operator name.
3626   const std::string &opName = namePrefix.str();
3627 
3628   // Get all size parameters.
3629   dim_t numDirections = (direction == RnnDirection::Bidirectional) ? 2 : 1;
3630   assert(X.dims().size() == 3 &&
3631          "ONNX RNN input 'X' should have 3 dimensions!");
3632   dim_t seqLength = X.dims()[0];
3633   dim_t batchSize = X.dims()[1];
3634   dim_t inputSize = X.dims()[2];
3635 
3636   // Validate W size.
3637   assert(W.dims().size() == 3 &&
3638          "ONNX RNN input 'W' should have 3 dimensions!");
3639   assert(W.dims()[0] == numDirections && W.dims()[1] == hiddenSize &&
3640          W.dims()[2] == inputSize && "ONNX RNN 'W' tensor size invalid!");
3641 
3642   // Validate R size.
3643   assert(R.dims().size() == 3 &&
3644          "ONNX RNN input 'R' should have 3 dimensions!");
3645   assert(R.dims()[0] == numDirections && R.dims()[1] == hiddenSize &&
3646          R.dims()[2] == hiddenSize && "ONNX RNN 'R' tensor size invalid!");
3647 
3648   // Validate B size.
3649   if (B.getNode()) {
3650     assert(B.dims().size() == 2 &&
3651            "ONNX RNN input 'B' should have 2 dimensions!");
3652     assert(B.dims()[0] == numDirections && B.dims()[1] == 2 * hiddenSize &&
3653            "ONNX RNN 'B' tensor size invalid!");
3654   }
3655 
3656   // Validate initial_h size.
3657   assert(initial_h.getNode() &&
3658          "ONNX RNN input 'initial_h' is mandatory. Null provided!");
3659   assert(initial_h.dims().size() == 3 &&
3660          "ONNX RNN input 'initial_h' should have 2 dimensions!");
3661   assert(initial_h.dims()[0] == numDirections &&
3662          initial_h.dims()[1] == batchSize &&
3663          initial_h.dims()[2] == hiddenSize &&
3664          "ONNX RNN 'initial_h' tensor size invalid!");
3665 
3666   // Validate number of activations.
3667   assert(activations.size() == numDirections * 1 &&
3668          "ONNX RNN activations vector invalid!");
3669 
3670   // Create X slices.
3671   std::vector<Node *> Xslices;
3672   for (dim_t t = 0; t < seqLength; t++) {
3673     auto XsliceName = opName + ".X" + std::to_string(t) + ".slice";
3674     Node *Xt = createSlice(XsliceName, X, RNN_X_SLICE_RANGE(t));
3675     auto XreshapeName = opName + ".X" + std::to_string(t) + ".reshape";
3676     Xt = createReshape(XreshapeName, Xt, {batchSize, inputSize});
3677     Xslices.push_back(Xt);
3678   }
3679 
3680   // Lambda to load forward/backward RNN cell.
3681   auto loadRNNCell = [&](bool forward, std::vector<NodeValue> &Yslices,
3682                          NodeValue &Hslice) {
3683     // Name prefix.
3684     std::string dirLabel = forward ? ".fw" : ".bw";
3685     std::string prefix = opName + ((numDirections > 1) ? dirLabel : "");
3686 
3687     // Slice index used for creating weights slices.
3688     dim_t sliceIdx0 = 0;
3689     if (direction == RnnDirection::Bidirectional) {
3690       sliceIdx0 = forward ? 0 : 1;
3691     }
3692 
3693     // Activations.
3694     size_t activationOffset = sliceIdx0 * 1;
3695     auto activationF = activations[activationOffset + 0];
3696 
3697     // Create W slice (Required).
3698     NodeValue Wi =
3699         createSlice(prefix + ".Wi.", W, RNN_W_SLICE_RANGE(sliceIdx0, 0));
3700     Wi = createReshape(prefix + ".Wi.reshape", Wi, {hiddenSize, inputSize});
3701     Wi = createTranspose(prefix + ".Wi.transp", Wi, {1, 0});
3702 
3703     // Create R slice (Required).
3704     NodeValue Ri =
3705         createSlice(prefix + ".Ri.", R, RNN_R_SLICE_RANGE(sliceIdx0, 0));
3706     Ri = createReshape(prefix + ".Ri.reshape", Ri, {hiddenSize, hiddenSize});
3707     Ri = createTranspose(prefix + ".Ri.transp", Ri, {1, 0});
3708 
3709     // Create B slices (optional).
3710     NodeValue bWi = nullptr;
3711     NodeValue bRi = nullptr;
3712 
3713     if (B) {
3714 
3715       bWi = createSlice(prefix + ".bWi.", B, RNN_B_SLICE_RANGE(sliceIdx0, 0));
3716       bRi = createSlice(prefix + ".bRi.", B, RNN_B_SLICE_RANGE(sliceIdx0, 1));
3717 
3718       bWi = createReshape(prefix + ".bWi.reshape", bWi, {hiddenSize});
3719       bRi = createReshape(prefix + ".bRi.reshape", bRi, {hiddenSize});
3720     }
3721 
3722     // Create H slice for this direction.
3723     Node *Hinit = createSlice(prefix + ".H.slice", initial_h,
3724                               RNN_H_SLICE_RANGE(sliceIdx0));
3725     Hinit =
3726         createReshape(prefix + ".H.reshape", Hinit, {batchSize, hiddenSize});
3727 
3728     // Initialize.
3729     Node *Ht = Hinit;
3730 
3731     // Unroll RNN cell for all time steps.
3732     for (size_t t = 0; t < seqLength; t++) {
3733 
3734       // Input for current time step.
3735       // For the reverse RNN cell the inputs are provided in reverse order.
3736       Node *Xt = forward ? Xslices[t] : Xslices[seqLength - 1 - t];
3737 
3738       // Hidden state update: Ht = f(Xt * Wi + bWi + Ht-1 * Ri + bRi).
3739       Ht = createAdd(prefix + ".H.add",
3740                      RNN_CREATE_FC(prefix + ".H.fc1", Xt, Wi, bWi),
3741                      RNN_CREATE_FC(prefix + ".H.fc2", Ht, Ri, bRi));
3742       Ht = activationF(prefix + ".H.act", Ht);
3743 
3744       // Output.
3745       Yslices.push_back(Ht);
3746     }
3747 
3748     // Updated states nodes.
3749     Hslice = Ht;
3750   }; // End of local lambda "loadRNNCell".
3751 
3752   bool forwardEnabled = ((direction == RnnDirection::Forward) ||
3753                          (direction == RnnDirection::Bidirectional));
3754   bool backwardEnabled = ((direction == RnnDirection::Reverse) ||
3755                           (direction == RnnDirection::Bidirectional));
3756 
3757   std::vector<NodeValue> YSlices;
3758   std::vector<NodeValue> Hslices;
3759 
3760   // Load forward RNN.
3761   std::vector<NodeValue> forwardYslices;
3762   if (forwardEnabled) {
3763     NodeValue forwardHslice;
3764     loadRNNCell(/* forward */ true, forwardYslices, forwardHslice);
3765     Hslices.push_back(forwardHslice);
3766   }
3767 
3768   // Load backward RNN.
3769   std::vector<NodeValue> backwardYslices;
3770   if (backwardEnabled) {
3771     NodeValue backwardHslice;
3772     loadRNNCell(/* forward */ false, backwardYslices, backwardHslice);
3773     Hslices.push_back(backwardHslice);
3774   }
3775 
3776   // Gather Y slices.
3777   for (size_t t = 0; t < seqLength; t++) {
3778     if (forwardEnabled) {
3779       YSlices.push_back(forwardYslices[t]);
3780     }
3781     if (backwardEnabled) {
3782       YSlices.push_back(backwardYslices[seqLength - 1 - t]);
3783     }
3784   }
3785 
3786   // Concatenate Y slices.
3787   // Y size is [seqLength, numDirections, batchSize, hiddenSize].
3788   Y = createReshape(opName + ".Y.reshape",
3789                     createConcat(opName + ".Y.concat", YSlices, 0),
3790                     {seqLength, numDirections, batchSize, hiddenSize});
3791 
3792   // Concatenate Y_h slices.
3793   // Y_h size is [numDirections, batchSize, hiddenSize].
3794   Y_h = createReshape(opName + ".Y_h.reshape",
3795                       createConcat(opName + ".Y_h.concat", Hslices, 0),
3796                       {numDirections, batchSize, hiddenSize});
3797 
3798 #undef RNN_X_SLICE_RANGE
3799 #undef RNN_W_SLICE_RANGE
3800 #undef RNN_R_SLICE_RANGE
3801 #undef RNN_B_SLICE_RANGE
3802 #undef RNN_H_SLICE_RANGE
3803 #undef RNN_CREATE_FC
3804 }
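
// Hedged sketch of driving createOnnxRNN. Constructing an RnnActivation from
// a lambda is an assumption based on how the activations are invoked in the
// loop above; "F", "X", "W", "R", "B" and "initial_h" are assumed to exist
// with the shapes validated at the top of the function.
//
//   std::vector<RnnActivation> acts = {[&](llvm::StringRef name, Node *in) {
//     return F->createTanh(name, in);
//   }};
//   NodeValue Y, Y_h;
//   F->createOnnxRNN("rnn", X, W, R, B, initial_h, Y, Y_h,
//                    /* hiddenSize */ 32, RnnDirection::Forward, acts);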
3805 
3806 void Function::createOnnxGRU(llvm::StringRef namePrefix, NodeValue X,
3807                              NodeValue W, NodeValue R, NodeValue B,
3808                              NodeValue initial_h, NodeValue &Y, NodeValue &Y_h,
3809                              unsigned hiddenSize, RnnDirection direction,
3810                              std::vector<RnnActivation> &activations,
3811                              bool linearBeforeReset) {
3812 
3813 #define GRU_X_SLICE_RANGE(idx)                                                 \
3814   {idx + 0, 0, 0}, { idx + 1, batchSize, inputSize }
3815 #define GRU_W_SLICE_RANGE(idx0, idx1)                                          \
3816   {idx0, idx1 * hiddenSize, 0}, { idx0 + 1, (idx1 + 1) * hiddenSize, inputSize }
3817 #define GRU_R_SLICE_RANGE(idx0, idx1)                                          \
3818   {idx0, idx1 * hiddenSize, 0}, {                                              \
3819     idx0 + 1, (idx1 + 1) * hiddenSize, hiddenSize                              \
3820   }
3821 #define GRU_B_SLICE_RANGE(idx0, idx1)                                          \
3822   {idx0, idx1 * hiddenSize}, { idx0 + 1, (idx1 + 1) * hiddenSize }
3823 #define GRU_H_SLICE_RANGE(idx)                                                 \
3824   {idx + 0, 0, 0}, { idx + 1, batchSize, hiddenSize }
3825 #define GRU_CREATE_FC(name, LHS, RHS, BIAS)                                    \
3826   BIAS ? (Node *)createFullyConnected(name, LHS, RHS, BIAS)                    \
3827        : (Node *)createMatMul(name, LHS, RHS)
3828 
3829   // Operator name.
3830   const std::string &opName = namePrefix.str();
3831 
3832   // Get all size parameters.
3833   dim_t numDirections = (direction == RnnDirection::Bidirectional) ? 2 : 1;
3834   assert(X.dims().size() == 3 &&
3835          "ONNX GRU input 'X' should have 3 dimensions!");
3836   dim_t seqLength = X.dims()[0];
3837   dim_t batchSize = X.dims()[1];
3838   dim_t inputSize = X.dims()[2];
3839 
3840   // Validate W size.
3841   assert(W.dims().size() == 3 &&
3842          "ONNX GRU input 'W' should have 3 dimensions!");
3843   assert(W.dims()[0] == numDirections && W.dims()[1] == 3 * hiddenSize &&
3844          W.dims()[2] == inputSize && "ONNX GRU 'W' tensor size invalid!");
3845 
3846   // Validate R size.
3847   assert(R.dims().size() == 3 &&
3848          "ONNX GRU input 'R' should have 3 dimensions!");
3849   assert(R.dims()[0] == numDirections && R.dims()[1] == 3 * hiddenSize &&
3850          R.dims()[2] == hiddenSize && "ONNX GRU 'R' tensor size invalid!");
3851 
3852   // Validate B size.
3853   if (B.getNode()) {
3854     assert(B.dims().size() == 2 &&
3855            "ONNX GRU input 'B' should have 2 dimensions!");
3856     assert(B.dims()[0] == numDirections && B.dims()[1] == 6 * hiddenSize &&
3857            "ONNX GRU 'B' tensor size invalid!");
3858   }
3859 
3860   // Validate initial_h size.
3861   assert(initial_h.getNode() &&
3862          "ONNX GRU input 'initial_h' is mandatory. Null provided!");
3863   assert(initial_h.dims().size() == 3 &&
3864          "ONNX GRU input 'initial_h' should have 2 dimensions!");
3865   assert(initial_h.dims()[0] == numDirections &&
3866          initial_h.dims()[1] == batchSize &&
3867          initial_h.dims()[2] == hiddenSize &&
3868          "ONNX GRU 'initial_h' tensor size invalid!");
3869 
3870   // Validate number of activations.
3871   assert(activations.size() == numDirections * 2 &&
3872          "ONNX GRU activations vector invalid!");
3873 
3874   // Create X slices.
3875   std::vector<Node *> Xslices;
3876   for (dim_t t = 0; t < seqLength; t++) {
3877     auto XsliceName = opName + ".X" + std::to_string(t) + ".slice";
3878     Node *Xt = createSlice(XsliceName, X, GRU_X_SLICE_RANGE(t));
3879     auto XreshapeName = opName + ".X" + std::to_string(t) + ".reshape";
3880     Xt = createReshape(XreshapeName, Xt, {batchSize, inputSize});
3881     Xslices.push_back(Xt);
3882   }
3883 
3884   // Lambda to load forward/backward GRU cell.
3885   auto loadGRUCell = [&](bool forward, std::vector<NodeValue> &Yslices,
3886                          NodeValue &Hslice) {
3887     // Name prefix.
3888     std::string dirLabel = forward ? ".fw" : ".bw";
3889     std::string prefix = opName + ((numDirections > 1) ? dirLabel : "");
3890 
3891     // Slice index used for creating weights slices.
3892     dim_t sliceIdx0 = 0;
3893     if (direction == RnnDirection::Bidirectional) {
3894       sliceIdx0 = forward ? 0 : 1;
3895     }
3896 
3897     // Activations.
3898     size_t activationOffset = sliceIdx0 * 2;
3899     auto activationF = activations[activationOffset + 0];
3900     auto activationG = activations[activationOffset + 1];
3901 
3902     // Create W slices (Required).
3903     NodeValue Wz =
3904         createSlice(prefix + ".Wz.", W, GRU_W_SLICE_RANGE(sliceIdx0, 0));
3905     NodeValue Wr =
3906         createSlice(prefix + ".Wr.", W, GRU_W_SLICE_RANGE(sliceIdx0, 1));
3907     NodeValue Wh =
3908         createSlice(prefix + ".Wh.", W, GRU_W_SLICE_RANGE(sliceIdx0, 2));
3909 
3910     Wz = createReshape(prefix + ".Wz.reshape", Wz, {hiddenSize, inputSize});
3911     Wr = createReshape(prefix + ".Wr.reshape", Wr, {hiddenSize, inputSize});
3912     Wh = createReshape(prefix + ".Wh.reshape", Wh, {hiddenSize, inputSize});
3913 
3914     Wz = createTranspose(prefix + ".Wz.transp", Wz, {1, 0});
3915     Wr = createTranspose(prefix + ".Wr.transp", Wr, {1, 0});
3916     Wh = createTranspose(prefix + ".Wh.transp", Wh, {1, 0});
3917 
3918     // Create R slices (Required).
3919     NodeValue Rz =
3920         createSlice(prefix + ".Rz.", R, GRU_R_SLICE_RANGE(sliceIdx0, 0));
3921     NodeValue Rr =
3922         createSlice(prefix + ".Rr.", R, GRU_R_SLICE_RANGE(sliceIdx0, 1));
3923     NodeValue Rh =
3924         createSlice(prefix + ".Rh.", R, GRU_R_SLICE_RANGE(sliceIdx0, 2));
3925 
3926     Rz = createReshape(prefix + ".Rz.reshape", Rz, {hiddenSize, hiddenSize});
3927     Rr = createReshape(prefix + ".Rr.reshape", Rr, {hiddenSize, hiddenSize});
3928     Rh = createReshape(prefix + ".Rh.reshape", Rh, {hiddenSize, hiddenSize});
3929 
3930     Rz = createTranspose(prefix + ".Rz.transp", Rz, {1, 0});
3931     Rr = createTranspose(prefix + ".Rr.transp", Rr, {1, 0});
3932     Rh = createTranspose(prefix + ".Rh.transp", Rh, {1, 0});
3933 
3934     // Create B slices (optional).
3935     NodeValue bWz = nullptr;
3936     NodeValue bWr = nullptr;
3937     NodeValue bWh = nullptr;
3938     NodeValue bRz = nullptr;
3939     NodeValue bRr = nullptr;
3940     NodeValue bRh = nullptr;
3941 
3942     if (B) {
3943 
3944       bWz = createSlice(prefix + ".bWz.", B, GRU_B_SLICE_RANGE(sliceIdx0, 0));
3945       bWr = createSlice(prefix + ".bWr.", B, GRU_B_SLICE_RANGE(sliceIdx0, 1));
3946       bWh = createSlice(prefix + ".bWh.", B, GRU_B_SLICE_RANGE(sliceIdx0, 2));
3947       bRz = createSlice(prefix + ".bRz.", B, GRU_B_SLICE_RANGE(sliceIdx0, 3));
3948       bRr = createSlice(prefix + ".bRr.", B, GRU_B_SLICE_RANGE(sliceIdx0, 4));
3949       bRh = createSlice(prefix + ".bRh.", B, GRU_B_SLICE_RANGE(sliceIdx0, 5));
3950 
3951       bWz = createReshape(prefix + ".bWz.reshape", bWz, {hiddenSize});
3952       bWr = createReshape(prefix + ".bWr.reshape", bWr, {hiddenSize});
3953       bWh = createReshape(prefix + ".bWh.reshape", bWh, {hiddenSize});
3954       bRz = createReshape(prefix + ".bRz.reshape", bRz, {hiddenSize});
3955       bRr = createReshape(prefix + ".bRr.reshape", bRr, {hiddenSize});
3956       bRh = createReshape(prefix + ".bRh.reshape", bRh, {hiddenSize});
3957     }
3958 
3959     // Create H slice for this direction.
3960     Node *Hinit = createSlice(prefix + ".H.slice", initial_h,
3961                               GRU_H_SLICE_RANGE(sliceIdx0));
3962     Hinit =
3963         createReshape(prefix + ".H.reshape", Hinit, {batchSize, hiddenSize});
3964 
3965     // Initialize.
3966     Node *Ht = Hinit;
3967 
3968     // Unroll GRU cell for all time steps.
3969     for (size_t t = 0; t < seqLength; t++) {
3970 
3971       // Input for current time step.
3972       // For the reverse GRU cell the inputs are provided in reverse order.
3973       Node *Xt = forward ? Xslices[t] : Xslices[seqLength - 1 - t];
3974 
3975       // Update gate: zt = f(Xt * Wz + bWz + Ht-1 * Rz + bRz).
3976       Node *zt = createAdd(prefix + ".Z.add1",
3977                            GRU_CREATE_FC(prefix + ".Z.fc1", Xt, Wz, bWz),
3978                            GRU_CREATE_FC(prefix + ".Z.fc2", Ht, Rz, bRz));
3979       zt = activationF(prefix + ".Z.act", zt);
3980 
3981       // Reset gate: rt = f(Xt * Wr + bWr + Ht-1 * Rr + bRr).
3982       Node *rt = createAdd(prefix + ".R.add1",
3983                            GRU_CREATE_FC(prefix + ".R.fc1", Xt, Wr, bWr),
3984                            GRU_CREATE_FC(prefix + ".R.fc2", Ht, Rr, bRr));
3985       rt = activationF(prefix + ".R.act", rt);
3986 
3987       // Hidden gate:
3988       // For linearBeforeReset = true:
3989       //   htild = g(Xt * Wh + bWh + rt . (Ht-1 * Rh + bRh)).
3990       // For linearBeforeReset = false:
3991       //   htild = g(Xt * Wh + bWh + (rt . Ht-1) * Rh + bRh).
3992       Node *htild;
3993       if (linearBeforeReset) {
3994         htild = createAdd(
3995             prefix + ".Htild.add",
3996             GRU_CREATE_FC(prefix + ".Htild.fc1", Xt, Wh, bWh),
3997             createMul(prefix + ".Htild.reset", rt,
3998                       GRU_CREATE_FC(prefix + ".Htild.fc2", Ht, Rh, bRh)));
3999       } else {
4000         htild = createAdd(
4001             prefix + ".Htild.add",
4002             GRU_CREATE_FC(prefix + ".Htild.fc1", Xt, Wh, bWh),
4003             GRU_CREATE_FC(prefix + ".Htild.fc2",
4004                           createMul(prefix + ".Htild.reset", rt, Ht), Rh, bRh));
4005       }
4006       htild = activationG(prefix + ".Htild.act", htild);
4007 
4008       // Hidden state update:
4009       // Ht = (1 - zt) . htild + zt . Ht-1 = htild - zt . htild + zt . Ht-1.
4010       Ht = createAdd(prefix + ".H.add",
4011                      createSub(prefix + ".H.sub", htild,
4012                                createMul(prefix + ".H.mult1", zt, htild)),
4013                      createMul(prefix + ".H.mult2", zt, Ht));
4014 
4015       // Output.
4016       Yslices.push_back(Ht);
4017     }
4018 
4019     // Updated states nodes.
4020     Hslice = Ht;
4021   }; // End of local lambda "loadGRUCell".
4022 
4023   bool forwardEnabled = ((direction == RnnDirection::Forward) ||
4024                          (direction == RnnDirection::Bidirectional));
4025   bool backwardEnabled = ((direction == RnnDirection::Reverse) ||
4026                           (direction == RnnDirection::Bidirectional));
4027 
4028   std::vector<NodeValue> YSlices;
4029   std::vector<NodeValue> Hslices;
4030 
4031   // Load forward GRU.
4032   std::vector<NodeValue> forwardYslices;
4033   if (forwardEnabled) {
4034     NodeValue forwardHslice;
4035     loadGRUCell(/* forward */ true, forwardYslices, forwardHslice);
4036     Hslices.push_back(forwardHslice);
4037   }
4038 
4039   // Load backward GRU.
4040   std::vector<NodeValue> backwardYslices;
4041   if (backwardEnabled) {
4042     NodeValue backwardHslice;
4043     loadGRUCell(/* forward */ false, backwardYslices, backwardHslice);
4044     Hslices.push_back(backwardHslice);
4045   }
4046 
4047   // Gather Y slices.
4048   for (size_t t = 0; t < seqLength; t++) {
4049     if (forwardEnabled) {
4050       YSlices.push_back(forwardYslices[t]);
4051     }
4052     if (backwardEnabled) {
4053       YSlices.push_back(backwardYslices[seqLength - 1 - t]);
4054     }
4055   }
4056 
4057   // Concatenate Y slices.
4058   // Y size is [seqLength, numDirections, batchSize, hiddenSize].
4059   Y = createReshape(opName + ".Y.reshape",
4060                     createConcat(opName + ".Y.concat", YSlices, 0),
4061                     {seqLength, numDirections, batchSize, hiddenSize});
4062 
4063   // Concatenate Y_h slices.
4064   // Y_h size is [numDirections, batchSize, hiddenSize].
4065   Y_h = createReshape(opName + ".Y_h.reshape",
4066                       createConcat(opName + ".Y_h.concat", Hslices, 0),
4067                       {numDirections, batchSize, hiddenSize});
4068 
4069 #undef GRU_X_SLICE_RANGE
4070 #undef GRU_W_SLICE_RANGE
4071 #undef GRU_R_SLICE_RANGE
4072 #undef GRU_B_SLICE_RANGE
4073 #undef GRU_H_SLICE_RANGE
4074 #undef GRU_CREATE_FC
4075 }
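
// Layout note implied by the slices above (stated here for reference): along
// their second dimension, 'W' and 'R' stack the update (z), reset (r) and
// hidden (h) gate blocks in that order, and 'B' stacks bWz, bWr, bWh, bRz,
// bRr, bRh (hence the 6 * hiddenSize check). For example, with hiddenSize = 4
// the reset-gate recurrence weights are rows [4, 8) of R for each direction.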
4076 
4077 void Function::createOnnxLSTM(llvm::StringRef namePrefix, NodeValue X,
4078                               NodeValue W, NodeValue R, NodeValue B,
4079                               NodeValue initial_h, NodeValue initial_c,
4080                               NodeValue P, NodeValue &Y, NodeValue &Y_h,
4081                               NodeValue &Y_c, unsigned hiddenSize,
4082                               RnnDirection direction,
4083                               std::vector<RnnActivation> &activations,
4084                               bool inputForget) {
4085 
4086 #define LSTM_X_SLICE_RANGE(idx)                                                \
4087   {idx + 0, 0, 0}, { idx + 1, batchSize, inputSize }
4088 #define LSTM_H_SLICE_RANGE(idx)                                                \
4089   {idx + 0, 0, 0}, { idx + 1, batchSize, hiddenSize }
4090 #define LSTM_C_SLICE_RANGE(idx)                                                \
4091   {idx + 0, 0, 0}, { idx + 1, batchSize, hiddenSize }
4092 #define LSTM_W_SLICE_RANGE(idx0, idx1)                                         \
4093   {idx0, idx1 * hiddenSize, 0}, { idx0 + 1, (idx1 + 1) * hiddenSize, inputSize }
4094 #define LSTM_R_SLICE_RANGE(idx0, idx1)                                         \
4095   {idx0, idx1 * hiddenSize, 0}, {                                              \
4096     idx0 + 1, (idx1 + 1) * hiddenSize, hiddenSize                              \
4097   }
4098 #define LSTM_B_SLICE_RANGE(idx0, idx1)                                         \
4099   {idx0, idx1 * hiddenSize}, { idx0 + 1, (idx1 + 1) * hiddenSize }
4100 #define LSTM_P_SLICE_RANGE(idx0, idx1)                                         \
4101   {idx0, idx1 * hiddenSize}, { idx0 + 1, (idx1 + 1) * hiddenSize }
4102 #define LSTM_CREATE_FC(name, LHS, RHS, BIAS)                                   \
4103   BIAS ? (Node *)createFullyConnected(name, LHS, RHS, BIAS)                    \
4104        : (Node *)createMatMul(name, LHS, RHS)
4105 
4106   // Operator name.
4107   const std::string &opName = namePrefix.str();
4108 
4109   // Get all size parameters.
4110   dim_t numDirections = (direction == RnnDirection::Bidirectional) ? 2 : 1;
4111   assert(X.dims().size() == 3 &&
4112          "ONNX LSTM input 'X' should have 3 dimensions!");
4113   dim_t seqLength = X.dims()[0];
4114   dim_t batchSize = X.dims()[1];
4115   dim_t inputSize = X.dims()[2];
4116 
4117   // Validate W size.
4118   assert(W.dims().size() == 3 &&
4119          "ONNX LSTM input 'W' should have 3 dimensions!");
4120   assert(W.dims()[0] == numDirections && W.dims()[1] == 4 * hiddenSize &&
4121          W.dims()[2] == inputSize && "ONNX LSTM 'W' tensor size invalid!");
4122 
4123   // Validate R size.
4124   assert(R.dims().size() == 3 &&
4125          "ONNX LSTM input 'R' should have 3 dimensions!");
4126   assert(R.dims()[0] == numDirections && R.dims()[1] == 4 * hiddenSize &&
4127          R.dims()[2] == hiddenSize && "ONNX LSTM 'R' tensor size invalid!");
4128 
4129   // Validate B size.
4130   if (B.getNode()) {
4131     assert(B.dims().size() == 2 &&
4132            "ONNX LSTM input 'B' should have 2 dimensions!");
4133     assert(B.dims()[0] == numDirections && B.dims()[1] == 8 * hiddenSize &&
4134            "ONNX LSTM 'B' tensor size invalid!");
4135   }
4136 
4137   // Validate initial_h size.
4138   assert(initial_h.getNode() &&
4139          "ONNX LSTM input 'initial_h' is mandatory. Null provided!");
4140   assert(initial_h.dims().size() == 3 &&
4141          "ONNX LSTM input 'initial_h' should have 2 dimensions!");
4142   assert(initial_h.dims()[0] == numDirections &&
4143          initial_h.dims()[1] == batchSize &&
4144          initial_h.dims()[2] == hiddenSize &&
4145          "ONNX LSTM 'initial_h' tensor size invalid!");
4146 
4147   // Validate initial_c size.
4148   assert(initial_c.getNode() &&
4149          "ONNX LSTM input 'initial_c' is mandatory. Null provided!");
4150   assert(initial_c.dims().size() == 3 &&
4151          "ONNX LSTM input 'initial_c' should have 2 dimensions!");
4152   assert(initial_c.dims()[0] == numDirections &&
4153          initial_c.dims()[1] == batchSize &&
4154          initial_c.dims()[2] == hiddenSize &&
4155          "ONNX LSTM 'initial_c' tensor size invalid!");
4156 
4157   // Validate P size.
4158   if (P.getNode()) {
4159     assert(P.dims().size() == 2 &&
4160            "ONNX LSTM input 'P' should have 2 dimensions!");
4161     assert(P.dims()[0] == numDirections && P.dims()[1] == 3 * hiddenSize &&
4162            "ONNX LSTM 'P' tensor size invalid!");
4163   }
4164 
4165   // Validate number of activations.
4166   assert(activations.size() == numDirections * 3 &&
4167          "ONNX LSTM activations vector invalid!");
4168 
4169   // Create X slices.
4170   std::vector<Node *> Xslices;
4171   for (dim_t t = 0; t < seqLength; t++) {
4172     auto XsliceName = opName + ".X" + std::to_string(t) + ".slice";
4173     Node *Xt = createSlice(XsliceName, X, LSTM_X_SLICE_RANGE(t));
4174     auto XreshapeName = opName + ".X" + std::to_string(t) + ".reshape";
4175     Xt = createReshape(XreshapeName, Xt, {batchSize, inputSize});
4176     Xslices.push_back(Xt);
4177   }
4178 
4179   // Lambda to load forward/backward LSTM cell.
4180   auto loadLSTMCell = [&](bool forward, std::vector<NodeValue> &Yslices,
4181                           NodeValue &Hslice, NodeValue &Cslice) {
4182     // Name prefix.
4183     std::string dirLabel = forward ? ".fw" : ".bw";
4184     std::string prefix = opName + ((numDirections > 1) ? dirLabel : "");
4185 
4186     // Slice index used for creating weights slices.
4187     dim_t sliceIdx0 = 0;
4188     if (direction == RnnDirection::Bidirectional) {
4189       sliceIdx0 = forward ? 0 : 1;
4190     }
4191 
4192     // Activations.
4193     size_t activationOffset = sliceIdx0 * 3;
4194     auto activationF = activations[activationOffset + 0];
4195     auto activationG = activations[activationOffset + 1];
4196     auto activationH = activations[activationOffset + 2];
4197 
4198     // Create W slices (Required).
4199     NodeValue Wi =
4200         createSlice(prefix + ".Wi.", W, LSTM_W_SLICE_RANGE(sliceIdx0, 0));
4201     NodeValue Wo =
4202         createSlice(prefix + ".Wo.", W, LSTM_W_SLICE_RANGE(sliceIdx0, 1));
4203     NodeValue Wf =
4204         createSlice(prefix + ".Wf.", W, LSTM_W_SLICE_RANGE(sliceIdx0, 2));
4205     NodeValue Wc =
4206         createSlice(prefix + ".Wc.", W, LSTM_W_SLICE_RANGE(sliceIdx0, 3));
4207 
4208     Wi = createReshape(prefix + ".Wi.reshape", Wi, {hiddenSize, inputSize});
4209     Wo = createReshape(prefix + ".Wo.reshape", Wo, {hiddenSize, inputSize});
4210     Wf = createReshape(prefix + ".Wf.reshape", Wf, {hiddenSize, inputSize});
4211     Wc = createReshape(prefix + ".Wc.reshape", Wc, {hiddenSize, inputSize});
4212 
4213     Wi = createTranspose(prefix + ".Wi.transp", Wi, {1, 0});
4214     Wo = createTranspose(prefix + ".Wo.transp", Wo, {1, 0});
4215     Wf = createTranspose(prefix + ".Wf.transp", Wf, {1, 0});
4216     Wc = createTranspose(prefix + ".Wc.transp", Wc, {1, 0});
4217 
4218     // Create R slices (Required).
4219     NodeValue Ri =
4220         createSlice(prefix + ".Ri.", R, LSTM_R_SLICE_RANGE(sliceIdx0, 0));
4221     NodeValue Ro =
4222         createSlice(prefix + ".Ro.", R, LSTM_R_SLICE_RANGE(sliceIdx0, 1));
4223     NodeValue Rf =
4224         createSlice(prefix + ".Rf.", R, LSTM_R_SLICE_RANGE(sliceIdx0, 2));
4225     NodeValue Rc =
4226         createSlice(prefix + ".Rc.", R, LSTM_R_SLICE_RANGE(sliceIdx0, 3));
4227 
4228     Ri = createReshape(prefix + ".Ri.reshape", Ri, {hiddenSize, hiddenSize});
4229     Ro = createReshape(prefix + ".Ro.reshape", Ro, {hiddenSize, hiddenSize});
4230     Rf = createReshape(prefix + ".Rf.reshape", Rf, {hiddenSize, hiddenSize});
4231     Rc = createReshape(prefix + ".Rc.reshape", Rc, {hiddenSize, hiddenSize});
4232 
4233     Ri = createTranspose(prefix + ".Ri.transp", Ri, {1, 0});
4234     Ro = createTranspose(prefix + ".Ro.transp", Ro, {1, 0});
4235     Rf = createTranspose(prefix + ".Rf.transp", Rf, {1, 0});
4236     Rc = createTranspose(prefix + ".Rc.transp", Rc, {1, 0});
4237 
4238     // Create B slices (optional).
4239     NodeValue bWi = nullptr;
4240     NodeValue bWo = nullptr;
4241     NodeValue bWf = nullptr;
4242     NodeValue bWc = nullptr;
4243     NodeValue bRi = nullptr;
4244     NodeValue bRo = nullptr;
4245     NodeValue bRf = nullptr;
4246     NodeValue bRc = nullptr;
4247 
4248     if (B) {
4249 
4250       bWi = createSlice(prefix + ".bWi.", B, LSTM_B_SLICE_RANGE(sliceIdx0, 0));
4251       bWo = createSlice(prefix + ".bWo.", B, LSTM_B_SLICE_RANGE(sliceIdx0, 1));
4252       bWf = createSlice(prefix + ".bWf.", B, LSTM_B_SLICE_RANGE(sliceIdx0, 2));
4253       bWc = createSlice(prefix + ".bWc.", B, LSTM_B_SLICE_RANGE(sliceIdx0, 3));
4254       bRi = createSlice(prefix + ".bRi.", B, LSTM_B_SLICE_RANGE(sliceIdx0, 4));
4255       bRo = createSlice(prefix + ".bRo.", B, LSTM_B_SLICE_RANGE(sliceIdx0, 5));
4256       bRf = createSlice(prefix + ".bRf.", B, LSTM_B_SLICE_RANGE(sliceIdx0, 6));
4257       bRc = createSlice(prefix + ".bRc.", B, LSTM_B_SLICE_RANGE(sliceIdx0, 7));
4258 
4259       bWi = createReshape(prefix + ".bWi.reshape", bWi, {hiddenSize});
4260       bWo = createReshape(prefix + ".bWo.reshape", bWo, {hiddenSize});
4261       bWf = createReshape(prefix + ".bWf.reshape", bWf, {hiddenSize});
4262       bWc = createReshape(prefix + ".bWc.reshape", bWc, {hiddenSize});
4263       bRi = createReshape(prefix + ".bRi.reshape", bRi, {hiddenSize});
4264       bRo = createReshape(prefix + ".bRo.reshape", bRo, {hiddenSize});
4265       bRf = createReshape(prefix + ".bRf.reshape", bRf, {hiddenSize});
4266       bRc = createReshape(prefix + ".bRc.reshape", bRc, {hiddenSize});
4267     }
4268 
4269     // Create P slices (optional).
4270     NodeValue Pi = nullptr;
4271     NodeValue Po = nullptr;
4272     NodeValue Pf = nullptr;
4273 
4274     if (P) {
4275 
4276       Pi = createSlice(prefix + ".Pi.", P, LSTM_P_SLICE_RANGE(sliceIdx0, 0));
4277       Po = createSlice(prefix + ".Po.", P, LSTM_P_SLICE_RANGE(sliceIdx0, 1));
4278       Pf = createSlice(prefix + ".Pf.", P, LSTM_P_SLICE_RANGE(sliceIdx0, 2));
4279 
4280       // Repeat P slices to match [batchSize, hiddenSize].
4281       Pi = createTile(prefix + ".Pi.repeat", Pi, batchSize, 0);
4282       Po = createTile(prefix + ".Po.repeat", Po, batchSize, 0);
4283       Pf = createTile(prefix + ".Pf.repeat", Pf, batchSize, 0);
4284     }
4285 
4286     // Create H slice for this direction.
4287     Node *Hinit = createSlice(prefix + ".H.slice", initial_h,
4288                               LSTM_H_SLICE_RANGE(sliceIdx0));
4289     Hinit =
4290         createReshape(prefix + ".H.reshape", Hinit, {batchSize, hiddenSize});
4291 
4292     // Create C slice for this direction.
4293     Node *Cinit = createSlice(prefix + ".C.slice", initial_c,
4294                               LSTM_C_SLICE_RANGE(sliceIdx0));
4295     Cinit =
4296         createReshape(prefix + ".C.reshape", Cinit, {batchSize, hiddenSize});
4297 
4298     // Initialize.
4299     Node *Ht = Hinit;
4300     Node *Ct = Cinit;
4301 
4302     // Unroll LSTM cell for all time steps.
4303     for (size_t t = 0; t < seqLength; t++) {
4304 
4305       // Input for current time step.
4306       // For the reverse LSTM cell the inputs are provided in reverse order.
4307       Node *Xt = forward ? Xslices[t] : Xslices[seqLength - 1 - t];
4308 
4309       // Forget gate: ft = f(Xt * Wf + bWf + Ht-1 * Rf + bRf + Pf . Ct-1).
4310       Node *ft = createAdd(prefix + ".F.add1",
4311                            LSTM_CREATE_FC(prefix + ".F.fc1", Xt, Wf, bWf),
4312                            LSTM_CREATE_FC(prefix + ".F.fc2", Ht, Rf, bRf));
4313       if (Pf) {
4314         ft = createAdd(prefix + ".F.add2", ft,
4315                        createMul(prefix + ".F.mult", Pf, Ct));
4316       }
4317       ft = activationF(prefix + ".F.act", ft);
4318 
4319       // Cell state candidate: ctild = g(Xt * Wc + bWc + Ht-1 * Rc + bRc).
4320       Node *ctild =
4321           createAdd(prefix + ".Ctild.add",
4322                     LSTM_CREATE_FC(prefix + ".Ctild.fc1", Xt, Wc, bWc),
4323                     LSTM_CREATE_FC(prefix + ".Ctild.fc2", Ht, Rc, bRc));
4324       ctild = activationG(prefix + ".Ctild.act", ctild);
4325 
4326       // Input gate:
4327       // For inputForget == true:
4328       //   it = 1 - ft.
4329       // For inputForget == false:
4330       //   it = f(Xt * Wi + bWi + Ht-1 * Ri + bRi + Pi . Ct-1).
4331       Node *it;
4332       if (inputForget) {
4333         auto splatTy = ft->getNthResult(0).getType();
4334         it = createSub(prefix + ".I.sub",
4335                        createSplat(prefix + ".I.splat", splatTy, 1.0), ft);
4336       } else {
4337         it = createAdd(prefix + ".I.add1",
4338                        LSTM_CREATE_FC(prefix + ".I.fc1", Xt, Wi, bWi),
4339                        LSTM_CREATE_FC(prefix + ".I.fc2", Ht, Ri, bRi));
4340         if (Pi) {
4341           it = createAdd(prefix + ".I.add2", it,
4342                          createMul(prefix + ".I.mult", Pi, Ct));
4343         }
4344         it = activationF(prefix + ".I.act", it);
4345       }
4346 
4347       // Cell state update: Ct = ft . Ct-1 + it . ctild.
4348       Ct = createAdd(prefix + ".C.add", createMul(prefix + ".C.mult1", ft, Ct),
4349                      createMul(prefix + ".C.mult2", it, ctild));
4350 
4351       // Output gate: ot = f(Xt * Wo + bWo + Ht-1 * Ro + bRo + Po . Ct).
4352       Node *ot = createAdd(prefix + ".O.add1",
4353                            LSTM_CREATE_FC(prefix + ".O.fc1", Xt, Wo, bWo),
4354                            LSTM_CREATE_FC(prefix + ".O.fc2", Ht, Ro, bRo));
4355       if (Po) {
4356         ot = createAdd(prefix + ".O.add2", ot,
4357                        createMul(prefix + ".O.mult", Po, Ct));
4358       }
4359       ot = activationF(prefix + ".O.act", ot);
4360 
4361       // Hidden state update: Ht = ot . h(Ct).
4362       Ht =
4363           createMul(prefix + ".H.mult", ot, activationH(prefix + ".H.act", Ct));
4364 
4365       // Output.
4366       Yslices.push_back(Ht);
4367     }
4368 
4369     // Updated states nodes.
4370     Hslice = Ht;
4371     Cslice = Ct;
4372   }; // End of local lambda "loadLSTMCell".
4373 
4374   bool forwardEnabled = ((direction == RnnDirection::Forward) ||
4375                          (direction == RnnDirection::Bidirectional));
4376   bool backwardEnabled = ((direction == RnnDirection::Reverse) ||
4377                           (direction == RnnDirection::Bidirectional));
4378 
4379   std::vector<NodeValue> YSlices;
4380   std::vector<NodeValue> Hslices;
4381   std::vector<NodeValue> Cslices;
4382 
4383   // Load forward LSTM.
4384   std::vector<NodeValue> forwardYslices;
4385   if (forwardEnabled) {
4386     NodeValue forwardHslice;
4387     NodeValue forwardCslice;
4388     loadLSTMCell(/* forward */ true, forwardYslices, forwardHslice,
4389                  forwardCslice);
4390     Hslices.push_back(forwardHslice);
4391     Cslices.push_back(forwardCslice);
4392   }
4393 
4394   // Load backward LSTM.
4395   std::vector<NodeValue> backwardYslices;
4396   if (backwardEnabled) {
4397     NodeValue backwardHslice;
4398     NodeValue backwardCslice;
4399     loadLSTMCell(/* forward */ false, backwardYslices, backwardHslice,
4400                  backwardCslice);
4401     Hslices.push_back(backwardHslice);
4402     Cslices.push_back(backwardCslice);
4403   }
4404 
4405   // Gather Y slices.
4406   for (size_t t = 0; t < seqLength; t++) {
4407     if (forwardEnabled) {
4408       YSlices.push_back(forwardYslices[t]);
4409     }
4410     if (backwardEnabled) {
4411       YSlices.push_back(backwardYslices[seqLength - 1 - t]);
4412     }
4413   }
4414 
4415   // Concatenate Y slices.
4416   // Y size is [seqLength, numDirections, batchSize, hiddenSize].
4417   Y = createReshape(opName + ".Y.reshape",
4418                     createConcat(opName + ".Y.concat", YSlices, 0),
4419                     {seqLength, numDirections, batchSize, hiddenSize});
4420 
4421   // Concatenate Y_h slices.
4422   // Y_h size is [numDirections, batchSize, hiddenSize].
4423   Y_h = createReshape(opName + ".Y_h.reshape",
4424                       createConcat(opName + ".Y_h.concat", Hslices, 0),
4425                       {numDirections, batchSize, hiddenSize});
4426 
4427   // Concatenate Y_c slices.
4428   // Y_c size is [numDirections, batchSize, hiddenSize].
4429   Y_c = createReshape(opName + ".Y_c.reshape",
4430                       createConcat(opName + ".Y_c.concat", Cslices, 0),
4431                       {numDirections, batchSize, hiddenSize});
4432 
4433 #undef LSTM_X_SLICE_RANGE
4434 #undef LSTM_H_SLICE_RANGE
4435 #undef LSTM_C_SLICE_RANGE
4436 #undef LSTM_W_SLICE_RANGE
4437 #undef LSTM_R_SLICE_RANGE
4438 #undef LSTM_B_SLICE_RANGE
4439 #undef LSTM_P_SLICE_RANGE
4440 #undef LSTM_CREATE_FC
4441 }
4442 
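/// Create a TraceEventNode used to instrument the graph for profiling. The
/// node is named "<functionName>_<eventName>_instrumentation" and records an
/// event of type \p eventType backed by the \p data node at position \p index.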
4443 TraceEventNode *Function::createTraceEvent(llvm::StringRef eventName,
4444                                            llvm::StringRef eventType,
4445                                            Node *data, unsigned index) {
4446   std::string name = (getName() + "_" + eventName + "_instrumentation").str();
4447   return addNode(new TraceEventNode(name, data, eventName, eventType, index));
4448 }
4449 
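/// Create a TensorFlow NonMaxSuppressionV4 node. Scores are given per box
/// (no class dimension). The indices output and the number-of-selected output
/// use element kind \p elTy and are allocated for the maximum possible number
/// of detections, i.e. \p maxOutputBoxesPerClass boxes (or all boxes when 0
/// is passed).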
4450 NonMaxSuppressionNode *Function::createNonMaxSuppressionV4(
4451     llvm::StringRef name, NodeValue boxes, NodeValue scores,
4452     int64_t centerPointBox, int64_t maxOutputBoxesPerClass, float iouThreshold,
4453     float scoreThreshold, ElemKind elTy) {
4454   // V4
4455   // Class/Score [BatchNum][BoxNum]
4456   // Boxes [BatchNum][BoxNum][4]
4457   // Result [BatchNum*MaxOutputPerBatch]
4458   // NumberOfIndicesDetected [BatchNum*MaxOutputPerBatch]
4459   auto scoresDim = scores.dims();
4460   int scoresBoxDim = scoresDim.size() - 1;
4461   if (maxOutputBoxesPerClass == 0) {
4462     maxOutputBoxesPerClass = scoresDim[scoresBoxDim];
4463   }
4464 
4465   // Allocating maximum because we don't know how many boxes will actually be
4466   // detected.
4467   std::vector<dim_t> newDim = {static_cast<dim_t>(maxOutputBoxesPerClass)};
4468   auto indicesTy = getParent()->uniqueType(elTy, newDim);
4469   auto numberOfSelectedIndicesTy = getParent()->uniqueType(
4470       elTy, {static_cast<dim_t>(maxOutputBoxesPerClass)});
4471   return addNode(new NonMaxSuppressionNode(
4472       name, indicesTy, numberOfSelectedIndicesTy, boxes, scores, centerPointBox,
4473       maxOutputBoxesPerClass, iouThreshold, scoreThreshold, true));
4474 }
4475 
4476 NonMaxSuppressionNode *
4477 Function::createNonMaxSuppressionV4(llvm::StringRef name, NodeValue boxes,
4478                                     NodeValue scores, int64_t centerPointBox,
4479                                     int64_t maxOutputBoxesPerClass,
4480                                     float iouThreshold, float scoreThreshold) {
4481   return createNonMaxSuppressionV4(name, boxes, scores, centerPointBox,
4482                                    maxOutputBoxesPerClass, iouThreshold,
4483                                    scoreThreshold, ElemKind::Int64ITy);
4484 }
4485 
4486 NonMaxSuppressionNode *Function::createNonMaxSuppressionV4(
4487     llvm::StringRef name, NodeValue boxes, NodeValue scores,
4488     int64_t centerPointBox, int64_t maxOutputBoxesPerClass, float iouThreshold,
4489     float scoreThreshold, TypeRef indicesTy,
4490     TypeRef numberOfSelectedIndicesTy) {
4491   assert(maxOutputBoxesPerClass > 0 && "Invalid maxOutputBoxesPerClass.");
4492 
4493   return addNode(new NonMaxSuppressionNode(
4494       name, indicesTy, numberOfSelectedIndicesTy, boxes, scores, centerPointBox,
4495       maxOutputBoxesPerClass, iouThreshold, scoreThreshold, true));
4496 }
4497 
4498 NonMaxSuppressionNode *Function::createNonMaxSuppressionONNX(
4499     llvm::StringRef name, NodeValue boxes, NodeValue scores,
4500     int64_t centerPointBox, int64_t maxOutputBoxesPerClass, float iouThreshold,
4501     float scoreThreshold, ElemKind elTy) {
4502   // ONNX
4503   // Class/Score [BatchNum][ClassNum][BoxNum]
4504   // Box [BatchNum][BoxNum][4]
4505   // Result [BatchNum*MaxOutputPerBatch][3]
4506   auto boxesDim = boxes.dims();
4507   auto scoresDim = scores.dims();
4508   int scoresBoxDim = scoresDim.size() - 1;
4509   int scoresClassDim = scoresDim.size() - 2;
4510   int scoresBatchDim = scoresDim.size() - 3;
4511   int boxesBatchDim = boxesDim.size() - 3;
4512   if (maxOutputBoxesPerClass == 0) {
4513     maxOutputBoxesPerClass = scoresDim[scoresBoxDim];
4514   }
4515 
4516   // Allocating maximum because we don't know how many boxes will actually be
4517   // detected.
4518   std::vector<dim_t> newDim = {scoresDim[scoresBatchDim] *
4519                                    scoresDim[scoresClassDim] *
4520                                    static_cast<dim_t>(maxOutputBoxesPerClass),
4521                                3};
4522   auto indicesTy = getParent()->uniqueType(elTy, newDim);
4523   auto numberOfSelectedIndicesTy = getParent()->uniqueType(
4524       elTy,
4525       {boxesDim[boxesBatchDim] * static_cast<dim_t>(maxOutputBoxesPerClass)});
4526   return addNode(new NonMaxSuppressionNode(
4527       name, indicesTy, numberOfSelectedIndicesTy, boxes, scores, centerPointBox,
4528       maxOutputBoxesPerClass, iouThreshold, scoreThreshold, false));
4529 }
4530 
4531 NonMaxSuppressionNode *Function::createNonMaxSuppressionONNX(
4532     llvm::StringRef name, NodeValue boxes, NodeValue scores,
4533     int64_t centerPointBox, int64_t maxOutputBoxesPerClass, float iouThreshold,
4534     float scoreThreshold) {
4535   return createNonMaxSuppressionONNX(name, boxes, scores, centerPointBox,
4536                                      maxOutputBoxesPerClass, iouThreshold,
4537                                      scoreThreshold, ElemKind::Int64ITy);
4538 }
4539 
4540 NonMaxSuppressionNode *Function::createNonMaxSuppressionONNX(
4541     llvm::StringRef name, NodeValue boxes, NodeValue scores,
4542     int64_t centerPointBox, int64_t maxOutputBoxesPerClass, float iouThreshold,
4543     float scoreThreshold, TypeRef indicesTy) {
4544   auto boxesDim = boxes.dims();
4545   assert(maxOutputBoxesPerClass > 0 && "Invalid maxOutputBoxesPerClass.");
4546 
4547   // Allocating maximum because we don't know how many boxes will actually be
4548   // detected.
4549   auto numberOfSelectedIndicesTy = getParent()->uniqueType(
4550       ElemKind::Int32ITy, {1, 1, 1,
4551                            boxesDim[boxesDim.size() - 2] *
4552                                static_cast<dim_t>(maxOutputBoxesPerClass)});
4553   return addNode(new NonMaxSuppressionNode(
4554       name, indicesTy, numberOfSelectedIndicesTy, boxes, scores, centerPointBox,
4555       maxOutputBoxesPerClass, iouThreshold, scoreThreshold, false));
4556 }
4557 
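/// Create a 1D float constant of \p length samples holding a Hann (raised
/// cosine) window: w[n] = 0.5 - 0.5 * cos(2 * pi * n / length).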
4558 Constant *Function::createCosineWindow(llvm::StringRef name, dim_t length) {
4559   auto window = getParent()->createConstant(ElemKind::FloatTy, {length}, name);
4560   auto windowH = window->getHandle<float>();
4561   for (dim_t n = 0; n < length; n++) {
4562     windowH.raw(n) =
4563         0.5 - 0.5 * cos(2.0 * M_PI * (double)(n) / (double)(length));
4564   }
4565   return window;
4566 }
4567 
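/// Create the twiddle factors W[k] = exp(-j * 2 * pi * k / fftLength) for a
/// complex FFT of length \p fftLength, stored as interleaved {cos, -sin}
/// pairs (2 * fftLength floats in total).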
4568 Constant *Function::createFFTTwiddleFactors(llvm::StringRef name,
4569                                             dim_t fftLength) {
4570   auto twiddleFactors =
4571       getParent()->createConstant(ElemKind::FloatTy, {2 * fftLength}, name);
4572   auto twiddleFactorsH = twiddleFactors->getHandle<float>();
4573   for (dim_t k = 0; k < fftLength; k++) {
4574     twiddleFactorsH.raw(2 * k + 0) =
4575         cos(2.0 * M_PI * (double)(k) / (double)(fftLength));
4576     twiddleFactorsH.raw(2 * k + 1) =
4577         -sin(2.0 * M_PI * (double)(k) / (double)(fftLength));
4578   }
4579   return twiddleFactors;
4580 }
4581 
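/// Create an Int32 constant holding the bit-reversed index permutation used
/// to reorder the input samples of a radix-2 FFT of length \p fftLength.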
4582 Constant *Function::createFFTBitReverseIndices(llvm::StringRef name,
4583                                                dim_t fftLength) {
4584   assert(fftLength >= 1 && "FFT length must be at least 1!");
4585   // Local function to reverse the bits of a number.
4586   auto reverseBits = [](uint64_t bits, dim_t numBits) -> uint64_t {
4587     assert(((0 <= numBits) && (numBits <= 64)) &&
4588            "Maximum number of bits exceeded for 'reverseBits' function!");
4589     if (numBits <= 0) {
4590       return 0;
4591     }
4592     uint64_t bitsRev = 0;
4593     uint64_t bitsMask = 1;
4594     uint64_t bitsRevMask = (uint64_t)1 << (numBits - 1);
4595     for (dim_t idx = 0; idx < numBits; idx++) {
4596       if (bits & bitsMask) {
4597         bitsRev |= bitsRevMask;
4598       }
4599       bitsMask <<= 1;
4600       bitsRevMask >>= 1;
4601     }
4602     return bitsRev;
4603   };
4604   auto bitReverseIndices =
4605       getParent()->createConstant(ElemKind::Int32ITy, {fftLength}, name);
4606   auto bitReverseIndicesH = bitReverseIndices->getHandle<int32_t>();
4607   dim_t numBits = std::log2((double)fftLength);
4608   for (dim_t idx = 0; idx < fftLength; idx++) {
4609     bitReverseIndicesH.raw(idx) =
4610         static_cast<int32_t>(reverseBits(idx, numBits));
4611   }
4612   return bitReverseIndices;
4613 }
4614 
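/// Create the interleaved {real, imaginary} coefficients used to map the
/// half-length complex FFT of a real input window back to its real spectrum
/// (see the AudioSpectrogram node). \p outLength pairs are generated for an
/// FFT of length \p fftLength.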
4615 Constant *Function::createFFTComplexToRealWeights(llvm::StringRef name,
4616                                                   dim_t fftLength,
4617                                                   dim_t outLength) {
4618   auto complexToRealWeights =
4619       getParent()->createConstant(ElemKind::FloatTy, {2 * outLength}, name);
4620   auto complexToRealWeightsH = complexToRealWeights->getHandle<float>();
4621   for (dim_t k = 0; k < outLength; k++) {
4622     complexToRealWeightsH.raw(2 * k + 0) =
4623         0.5 * (1 - sin(2.0 * M_PI * (double)(k) / (double)(fftLength)));
4624     complexToRealWeightsH.raw(2 * k + 1) =
4625         -0.5 * cos(2.0 * M_PI * (double)(k) / (double)(fftLength));
4626   }
4627   return complexToRealWeights;
4628 }
4629 
4630 AudioSpectrogramNode *Function::createAudioSpectrogram(llvm::StringRef name,
4631                                                        NodeValue input,
4632                                                        int64_t windowSize,
4633                                                        int64_t windowStride,
4634                                                        bool magnitudeSquared) {
4635   // Output shape will be windowCount x (fftLength / 2 + 1).
4636   dim_t inputLength = input.getType()->size();
4637   dim_t windowCount = std::floor((inputLength - windowSize) / windowStride) + 1;
4638   dim_t fftLength = 1 << (dim_t)std::ceil(std::log2((double)windowSize));
4639   auto spectrogramTy = getParent()->uniqueType(
4640       ElemKind::FloatTy, {windowCount, fftLength / 2 + 1});
4641 
4642   // Create a cosine FFT windowing function.
4643   auto window = createCosineWindow(std::string(name) + ".Window", windowSize);
4644 
4645   // Create the FFT weights for a fftLength/2 complex FFT.
4646   auto twiddleFactors = createFFTTwiddleFactors(
4647       std::string(name) + ".TwiddleFactors", fftLength / 2);
4648   auto bitReverseIndices = createFFTBitReverseIndices(
4649       std::string(name) + ".BitReverseIndices", fftLength / 2);
4650 
4651   // Create the complex to real FFT mapping coefficients.
4652   // For small FFT lengths, make sure to generate at least 1 coefficient.
4653   auto complexToRealWeights = createFFTComplexToRealWeights(
4654       std::string(name) + ".ComplexToRealWeights", fftLength,
4655       (fftLength / 4) >= 1 ? (fftLength / 4) : 1);
4656 
4657   // Create AudioSpectrogram node.
4658   return addNode(new AudioSpectrogramNode(
4659       name, spectrogramTy, input, window, twiddleFactors, bitReverseIndices,
4660       complexToRealWeights, windowSize, windowStride, magnitudeSquared));
4661 }
4662 
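/// Build the Mel filter bank used by the MFCC node. \p melWeights receives
/// the flattened triangular filter weights and \p melRanges the inclusive
/// [start, stop] frequency bin indices of each of the \p filterBankCount Mel
/// bins, computed for a spectrogram with \p spectrogramLength frequency bins,
/// a sample rate of \p sampleRate and the band
/// [\p lowerFrequency, \p upperFrequency].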
4663 void Function::createMelWeights(llvm::StringRef prefix, dim_t spectrogramLength,
4664                                 float sampleRate, float lowerFrequency,
4665                                 float upperFrequency, dim_t filterBankCount,
4666                                 Constant *&melWeights, Constant *&melRanges) {
4667   auto fftLength = 2 * (spectrogramLength - 1);
4668   dim_t numFreqBins = fftLength / 2;
4669   dim_t numMelBins = filterBankCount;
4670 
4671   // Mel frequency scale local lambda function.
4672   auto melFreqScale = [](float freq) -> float {
4673     return 1127.0f * logf(1.0f + freq / 700.0f);
4674   };
4675 
4676   // Always exclude DC (TensorFlow implementation choice from HTK).
4677   float freqDelta = sampleRate / (float)(fftLength);
4678   dim_t freqIdxMin = (dim_t)(1.5 + (lowerFrequency / freqDelta));
4679   dim_t freqIdxMax = (dim_t)(upperFrequency / freqDelta);
4680   freqIdxMax = (freqIdxMax >= numFreqBins) ? numFreqBins : freqIdxMax;
4681 
4682   // Create Mel ranges constant.
4683   melRanges = getParent()->createConstant(ElemKind::Int32ITy, {2 * numMelBins},
4684                                           std::string(prefix) + ".MelRanges");
4685   auto melRangesH = melRanges->getHandle<int32_t>();
4686 
4687   // Mel weights and frequency start/stop (inclusive) buffers.
4688   auto melBinFreqWeights = std::make_unique<float[]>(numMelBins * numFreqBins);
4689   dim_t melBinFreqWeightsNum = 0;
4690 
4691   // Mel frequency limits.
4692   float melFreqLower = melFreqScale(lowerFrequency);
4693   float melFreqUpper = melFreqScale(upperFrequency);
4694   float melFreqDelta = (melFreqUpper - melFreqLower) / (numMelBins + 1);
4695   for (dim_t melIdx = 0; melIdx < numMelBins; melIdx++) {
4696 
4697     float melFreqLeft = melFreqLower + (melIdx + 0) * melFreqDelta;
4698     float melFreqCenter = melFreqLower + (melIdx + 1) * melFreqDelta;
4699     float melFreqRight = melFreqLower + (melIdx + 2) * melFreqDelta;
4700 
4701     int32_t freqIdxStart = -1;
4702     int32_t freqIdxStop = -2;
4703 
4704     for (dim_t freqIdx = freqIdxMin; freqIdx <= freqIdxMax; freqIdx++) {
4705       float melFreq = melFreqScale(freqIdx * freqDelta);
4706       if ((melFreqLeft < melFreq) && (melFreq < melFreqRight)) {
4707 
4708         // Compute frequency bin weight for this Mel bin.
4709         float weight = 1.0f - std::abs(melFreq - melFreqCenter) / melFreqDelta;
4710 
4711         // Store the frequency bin weight.
4712         melBinFreqWeights[melBinFreqWeightsNum++] = weight;
4713 
4714         // Update frequency bin start/stop index.
4715         if (freqIdxStart == -1) {
4716           freqIdxStart = freqIdx;
4717         }
4718         freqIdxStop = freqIdx;
4719       }
4720     }
4721 
4722     // Store the frequency bin start/stop index.
4723     melRangesH.raw(2 * melIdx + 0) = freqIdxStart;
4724     melRangesH.raw(2 * melIdx + 1) = freqIdxStop;
4725   }
4726 
4727   // Validate Mel ranges.
4728   dim_t melBinFreqWeightsNumValidate = 0;
4729   for (dim_t melIdx = 0; melIdx < numMelBins; melIdx++) {
4730     int32_t freqIdxRange =
4731         melRangesH.raw(2 * melIdx + 1) - melRangesH.raw(2 * melIdx + 0) + 1;
4732     melBinFreqWeightsNumValidate += freqIdxRange;
4733   }
4734   assert(melBinFreqWeightsNum == melBinFreqWeightsNumValidate &&
4735          "Invalid Mel ranges");
4736 
4737   // Create Mel weights constant.
4738   melWeights =
4739       getParent()->createConstant(ElemKind::FloatTy, {melBinFreqWeightsNum},
4740                                   std::string(prefix) + ".MelWeights");
4741   auto melWeightsH = melWeights->getHandle<float>();
4742   for (dim_t idx = 0; idx < melBinFreqWeightsNum; idx++) {
4743     melWeightsH.raw(idx) = melBinFreqWeights[idx];
4744   }
4745 }
4746 
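/// Create a \p K x \p N DCT-II transform matrix with entries
/// sqrt(2 / N) * cos(pi / N * (n + 0.5) * k), used to derive the MFCC
/// coefficients from the Mel filter bank output.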
4747 Constant *Function::createDCTMat(llvm::StringRef name, dim_t N, dim_t K) {
4748   Constant *dctMat =
4749       getParent()->createConstant(ElemKind::FloatTy, {K, N}, name);
4750   auto dctMatH = dctMat->getHandle<float>();
4751   float dctFact = (float)sqrt(2.0 / (double)(N));
4752   for (dim_t k = 0; k < K; k++) {
4753     for (dim_t n = 0; n < N; n++) {
4754       dctMatH.at({k, n}) =
4755           dctFact * cos(M_PI / (double)(N) * ((double)(n) + 0.5) * (double)(k));
4756     }
4757   }
4758   return dctMat;
4759 }
4760 
4761 MFCCNode *Function::createMFCC(llvm::StringRef name, NodeValue spectrogram,
4762                                float sampleRate, float lowerFrequency,
4763                                float upperFrequency, int64_t filterBankCount,
4764                                int64_t numCoefficients) {
4765   // Create the Mel weights.
4766   dim_t spectrogramLength = spectrogram.dims()[1];
4767   Constant *melWeights;
4768   Constant *melRanges;
4769   createMelWeights(name, spectrogramLength, sampleRate, lowerFrequency,
4770                    upperFrequency, filterBankCount, melWeights, melRanges);
4771 
4772   // Create the DCT transform matrix.
4773   Constant *dctMat = createDCTMat(std::string(name) + ".DCTMat",
4774                                   filterBankCount, numCoefficients);
4775 
4776   // Output shape will be windowCount x numCoefficients.
4777   dim_t windowCount = spectrogram.dims()[0];
4778   auto coefficientsTy = getParent()->uniqueType(
4779       ElemKind::FloatTy, {windowCount, static_cast<dim_t>(numCoefficients)});
4780 
4781   // Create MFCC node.
4782   return addNode(new MFCCNode(name, coefficientsTy, spectrogram, melWeights,
4783                               melRanges, dctMat, sampleRate, lowerFrequency,
4784                               upperFrequency, filterBankCount,
4785                               numCoefficients));
4786 }
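
// A minimal usage sketch, assuming an existing Function *F and a 1D float
// NodeValue "audio" created elsewhere: compute a power spectrogram and
// 13 MFCC coefficients from a 40-filter Mel bank over 16 kHz audio.
//   auto *spec = F->createAudioSpectrogram("spec", audio, /* windowSize */ 512,
//                                          /* windowStride */ 256,
//                                          /* magnitudeSquared */ true);
//   auto *mfcc = F->createMFCC("mfcc", spec, /* sampleRate */ 16000.0f,
//                              /* lowerFrequency */ 20.0f,
//                              /* upperFrequency */ 4000.0f,
//                              /* filterBankCount */ 40,
//                              /* numCoefficients */ 13);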
4787 
4788 //===----------------------------------------------------------------------===//
4789 //                   Graph dumping and printing
4790 //===----------------------------------------------------------------------===//
4791 
4792 void Function::dump() const {
4793   llvm::outs() << "Graph structure " << getName() << ":\n";
4794   for (auto &n : nodes_) {
4795     llvm::outs() << n.getDebugDesc();
4796   }
4797 }
4798 
4799 std::string Function::toString(bool skipUsersForStorage, bool skipName) const {
4800   std::string storage;
4801   llvm::raw_string_ostream os(storage);
4802   dump(os, skipUsersForStorage, skipName);
4803   return os.str();
4804 }
4805 
4806 llvm::hash_code Function::getHash() const {
4807   // Omit function name when generating the hash.
4808   return llvm::hash_value(toString(/* skipUsersForStorage */ false,
4809                                    /* skipName */ true));
4810 }
4811 
4812 void Function::dump(llvm::raw_ostream &os, bool skipUsersForStorage,
4813                     bool skipName) const {
4814   os << "Graph structure";
4815   if (!skipName) {
4816     os << " " << getName();
4817   }
4818   os << ":\n";
4819   std::set<const Node *, SortNamed> sorted;
4820   for (const Node &n : nodes_) {
4821     sorted.insert(&n);
4822   }
4823   for (auto *n : sorted) {
4824     os << n->getDebugDesc();
4825   }
4826   for (auto *C : getNamedSorted(findConstants())) {
4827     os << C->getDebugDesc(skipUsersForStorage);
4828   }
4829   for (auto *P : getNamedSorted(findPlaceholders())) {
4830     os << P->getDebugDesc(skipUsersForStorage);
4831   }
4832 }
4833 
4834 /// We can't use NodeWalker here, because it ignores result indices, which
4835 /// are critical in generating detailed debug output.
4836 class FunctionDottyPrinter : public AbstractDottyPrinter {
4837   // A set of already visited (during graph walk) nodes.
4838   std::unordered_set<Node *> visitedNodes_{};
4839 
4840   /// Recursively traverses the inputs of node \p N using Depth-First Search.
4841   /// Each node will be visited no more than once. The method also dumps
4842   /// edges with their port identifiers in dotty format.
4843   void visitNode(Node *N) {
4844     if (visitedNodes_.find(N) != visitedNodes_.end())
4845       return;
4846     visitedNodes_.insert(N);
4847 
4848     dumpNode(N, false);
4849 
4850     // Print edges for the predicate field, if it's used.
4851     if (N->hasPredicate()) {
4852       auto pred = N->getPredicate();
4853       size_t resNo = pred.getResNo();
4854       std::ostringstream edge;
4855       edge << pred.getNode()->getName().str() << ":"
4856            << pred.getNode()->getOutputName(resNo).str() << " -> "
4857            << N->getName().str() << ":w";
4858       dumpEdgeStyle(N, 0, pred, edge);
4859       edges_.insert(edge.str());
4860       visitNode(pred);
4861     }
4862 
4863     for (size_t i = 0; i < N->getNumInputs(); i++) {
4864       Node *to = N->getNthInput(i).getNode();
4865       size_t resNo = N->getNthInput(i).getResNo();
4866 
4867       std::ostringstream edge;
4868       edge << to->getName().str() << ":" << to->getOutputName(resNo).str()
4869            << " -> " << N->getName().str() << ":" << N->getInputName(i);
4870       dumpEdgeStyle(N, i, to, edge);
4871       edges_.insert(edge.str());
4872 
4873       visitNode(to);
4874     }
4875   }
4876 
4877 public:
4878   void visitGraph(Function *F) {
4879     for (auto &N : F->getNodes()) {
4880       visitNode(&N);
4881     }
4882   }
4883 };
4884 
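// Usage sketch: F->dumpDAG("graph.dot") writes the function in Graphviz dot
// format; the resulting file can be rendered offline, for example with
// "dot -Tpdf graph.dot -o graph.pdf". The argument-less overload writes to a
// temporary file and returns its path.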
4885 std::string Function::dumpDAG() {
4886   llvm::SmallString<64> dotPath;
4887   llvm::sys::fs::createTemporaryFile("dotty_graph_dump", "dot", dotPath);
4888   dumpDAG(dotPath);
4889 
4890   return std::string(dotPath.begin(), dotPath.end());
4891 }
4892 
4893 void Function::dumpDAG(llvm::StringRef dotFilename) {
4894   llvm::outs() << "Writing dotty graph for Function to: " << dotFilename
4895                << '\n';
4896 
4897   FunctionDottyPrinter DP;
4898 
4899   DP.visitGraph(this);
4900 
4901   std::ofstream myfile;
4902   myfile.open(dotFilename);
4903   DP.dumpAll(myfile);
4904   myfile.close();
4905 }
4906 
4907 void Function::dumpDAG(const char *dotFilename) {
4908   dumpDAG(llvm::StringRef(dotFilename));
4909 }
4910 
4911 Node *Function::getNodeByName(llvm::StringRef name) {
4912   for (auto &N : getNodes()) {
4913     if (N.getName().equals(name)) {
4914       return &N;
4915     }
4916   }
4917   return nullptr;
4918 }
4919 
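/// \returns the NodeValue identified by \p name. The name may refer to a
/// node, constant or placeholder, optionally followed by ":<resNo>" to select
/// a specific result (e.g. "relu1:0", where "relu1" is a hypothetical node
/// name). Returns an empty NodeValue if no such node exists or it has no
/// results.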
4920 NodeValue Function::getNodeValueByName(llvm::StringRef name) {
4921   auto strPair = name.split(':');
4922   // Search node, constant or placeholder.
4923   auto nodeName = strPair.first;
4924   Node *node = getNodeByName(nodeName);
4925   node = node ? node : getParent()->getConstantByName(nodeName);
4926   node = node ? node : getParent()->getPlaceholderByNameSlow(nodeName);
4927   if (!node || (node->getNumResults() == 0)) {
4928     return NodeValue();
4929   }
4930   // Get result number.
4931   if (node->getNumResults() == 1) {
4932     return NodeValue(node);
4933   } else {
4934     unsigned resNo = 0;
4935     CHECK(!strPair.second.getAsInteger(0, resNo)) << "Invalid node value name!";
4936     return NodeValue(node, resNo);
4937   }
4938 }
4939 
4940 void Module::eraseConstant(ConstList::iterator I) {
4941   if (I == constants_.end())
4942     return;
4943   logStorageDeletion(functions_, *I);
4944   delete *I;
4945   constants_.erase(I);
4946 }
4947 
4948 void Module::erasePlaceholder(PlaceholderList::iterator I) {
4949   if (I == placeholders_.end()) {
4950     return;
4951   }
4952 
4953   logStorageDeletion(functions_, *I);
4954   delete *I;
4955   placeholders_.erase(I);
4956 }
4957 
4958 void Function::eraseNode(NodesList::iterator I) {
4959   // Log node deletion.
4960   logCtx_->logNodeDeletion(*I);
4961 
4962   nodes_.erase(I);
4963 }
4964 
4965 Constant *Module::getConstantByName(llvm::StringRef name) const {
4966   for (auto *V : getConstants()) {
4967     if (V->getName() == name)
4968       return V;
4969   }
4970   return nullptr;
4971 }
4972 
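/// Randomize the payload of every Constant used by this function. Typed
/// constants are randomized within their current [min, max] value range,
/// while fused quantized types use the full uint8_t range. Constants feeding
/// the input indices listed in \p ignoredConstants (keyed by the user node's
/// kind) are left untouched, and a fatal error is raised for constants shared
/// with another function.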
4973 void Function::randomizeConstants(
4974     const std::map<Kinded::Kind, std::set<unsigned>> &ignoredConstants) {
4975   for (Constant *c : getParent()->getConstants()) {
4976     bool usedHere = false;
4977     bool usedElsewhere = false;
4978     bool ignored = false;
4979 
4980     for (auto &user : c->getUsers()) {
4981       auto *nodeUser = user.getUser();
4982       if (nodeUser->getParent() == this) {
4983         usedHere = true;
4984       } else {
4985         usedElsewhere = true;
4986       }
4987 
4988       auto kind = nodeUser->getKind();
4989       if (ignoredConstants.count(kind)) {
4990         for (auto idx : ignoredConstants.at(kind)) {
4991           if (nodeUser->getNthInput(idx).getNode() == c) {
4992             ignored = true;
4993             break;
4994           }
4995         }
4996       }
4997     }
4998 
4999     if (!usedHere) {
5000       continue;
5001     }
5002 
5003     if (usedElsewhere) {
5004       LOG(FATAL) << "Can't randomize Constant \"" << c->getName().str()
5005                  << "\" because it is used by another function";
5006     }
5007 
5008     if (ignored) {
5009       continue;
5010     }
5011 
5012     auto &payload = c->getPayloadMutable();
5013 
5014     switch (c->getElementType()) {
5015     case ElemKind::FloatTy: {
5016       auto H = payload.getHandle<float>();
5017       auto minMaxArg = H.minMaxArg();
5018       H.randomize(H.raw(minMaxArg.first), H.raw(minMaxArg.second), getPRNG());
5019       break;
5020     }
5021     case ElemKind::Float16Ty: {
5022       auto H = payload.getHandle<float16_t>();
5023       auto minMaxArg = H.minMaxArg();
5024       H.randomize(H.raw(minMaxArg.first), H.raw(minMaxArg.second), getPRNG());
5025       break;
5026     }
5027     case ElemKind::BFloat16Ty: {
5028       auto H = payload.getHandle<bfloat16_t>();
5029       auto minMaxArg = H.minMaxArg();
5030       H.randomize(H.raw(minMaxArg.first), H.raw(minMaxArg.second), getPRNG());
5031       break;
5032     }
5033     case ElemKind::Int8QTy: {
5034       auto H = payload.getHandle<int8_t>();
5035       auto minMaxArg = H.minMaxArg();
5036       H.randomize(H.raw(minMaxArg.first), H.raw(minMaxArg.second), getPRNG());
5037       break;
5038     }
5039     case ElemKind::UInt8QTy: {
5040       auto H = payload.getHandle<uint8_t>();
5041       auto minMaxArg = H.minMaxArg();
5042       H.randomize(H.raw(minMaxArg.first), H.raw(minMaxArg.second), getPRNG());
5043       break;
5044     }
5045     case ElemKind::Int16QTy: {
5046       auto H = payload.getHandle<int16_t>();
5047       auto minMaxArg = H.minMaxArg();
5048       H.randomize(H.raw(minMaxArg.first), H.raw(minMaxArg.second), getPRNG());
5049       break;
5050     }
5051     case ElemKind::Int32QTy: {
5052       auto H = payload.getHandle<int32_t>();
5053       auto minMaxArg = H.minMaxArg();
5054       H.randomize(H.raw(minMaxArg.first), H.raw(minMaxArg.second), getPRNG());
5055       break;
5056     }
5057     case ElemKind::Int32ITy: {
5058       auto H = payload.getHandle<int32_t>();
5059       auto minMaxArg = H.minMaxArg();
5060       H.randomize(H.raw(minMaxArg.first), H.raw(minMaxArg.second), getPRNG());
5061       break;
5062     }
5063     case ElemKind::Int64ITy: {
5064       auto H = payload.getHandle<int64_t>();
5065       auto minMaxArg = H.minMaxArg();
5066       H.randomize(H.raw(minMaxArg.first), H.raw(minMaxArg.second), getPRNG());
5067       break;
5068     }
5069     case ElemKind::UInt8FusedQTy:
5070       payload.getHandle<uint8_t>().randomize(
5071           std::numeric_limits<uint8_t>::lowest(),
5072           std::numeric_limits<uint8_t>::max(), getPRNG());
5073       break;
5074     case ElemKind::UInt8FusedFP16QTy:
5075       payload.getHandle<uint8_t>().randomize(
5076           std::numeric_limits<uint8_t>::lowest(),
5077           std::numeric_limits<uint8_t>::max(), getPRNG());
5078       break;
5079     case ElemKind::UInt4FusedFP16QTy:
5080       payload.getHandle<uint8_t>().randomize(
5081           std::numeric_limits<uint8_t>::lowest(),
5082           std::numeric_limits<uint8_t>::max(), getPRNG());
5083       break;
5084     case ElemKind::BoolTy:
5085       payload.getHandle<bool>().randomize(false, true, getPRNG());
5086       break;
5087     default:
5088       LOG(FATAL) << "Unsupported ElemKind";
5089     }
5090   }
5091 }
5092 
5093 Placeholder *Module::getPlaceholderByNameSlow(llvm::StringRef name) const {
5094   for (auto *P : getPlaceholders()) {
5095     if (P->getName() == name) {
5096       return P;
5097     }
5098   }
5099 
5100   return nullptr;
5101 }
5102 
5103 void Module::eraseConstant(Constant *N) {
5104   auto &vars = getConstants();
5105   auto I = std::find(vars.begin(), vars.end(), N);
5106   eraseConstant(I);
5107 }
5108 
5109 void Function::eraseNode(Node *N) {
5110   if (Constant *V = dyn_cast<Constant>(N)) {
5111     return getParent()->eraseConstant(V);
5112   }
5113   assert(std::find_if(nodes_.begin(), nodes_.end(),
5114                       [N](const Node &node) -> bool { return &node == N; }) !=
5115              nodes_.end() &&
5116          "Could not find node to delete!");
5117   eraseNode(N->getIterator());
5118 }
5119 
5120 PlaceholderList Function::findPlaceholders() {
5121   PlaceholderList list;
5122   for (auto &PH : parent_->getPlaceholders()) {
5123     for (auto &user : PH->getUsers()) {
5124       if (user.getUser()->getParent() == this) {
5125         list.push_back(PH);
5126         break;
5127       }
5128     }
5129   }
5130   return list;
5131 }
5132 
5133 PlaceholderList Function::findPlaceholders() const {
5134   PlaceholderList list;
5135   for (auto &PH : parent_->getPlaceholders()) {
5136     for (auto &user : PH->getUsers()) {
5137       if (user.getUser()->getParent() == this) {
5138         list.push_back(PH);
5139         break;
5140       }
5141     }
5142   }
5143   return list;
5144 }
5145 
5146 ConstList Function::findConstants() {
5147   ConstList list;
5148   for (auto &constant : parent_->getConstants()) {
5149     for (auto &user : constant->getUsers()) {
5150       if (user.getUser()->getParent() == this) {
5151         list.push_back(constant);
5152         break;
5153       }
5154     }
5155   }
5156   return list;
5157 }
5158 
5159 ConstList Function::findConstants() const {
5160   ConstList list;
5161   for (auto &constant : parent_->getConstants()) {
5162     for (auto &user : constant->getUsers()) {
5163       if (user.getUser()->getParent() == this) {
5164         list.push_back(constant);
5165         break;
5166       }
5167     }
5168   }
5169   return list;
5170 }
5171 
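/// Clone this function into a new function named \p newName inside the same
/// module. If \p map is provided it receives the mapping from original nodes
/// to their clones; \p currToNewMap may seed that mapping, e.g. with storage
/// nodes that were already cloned.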
5172 Function *Function::clone(llvm::StringRef newName,
5173                           llvm::DenseMap<const Node *, Node *> *map,
5174                           llvm::DenseMap<const Node *, Node *> *currToNewMap) {
5175   Module *M = getParent();
5176   auto *newF = M->createFunction(newName);
5177   return clone(newF, map, currToNewMap);
5178 }
5179 
5180 Function *
5181 Function::clone(Function *newF, llvm::DenseMap<const Node *, Node *> *map,
5182                 llvm::DenseMap<const Node *, Node *> *currToNewMap) const {
5183   // Maps current nodes to new nodes.
5184   llvm::DenseMap<const Node *, Node *> currToNew;
5185 
5186   // Initialize the map from a user-provided map.
5187   if (currToNewMap) {
5188     currToNew.insert(currToNewMap->begin(), currToNewMap->end());
5189   }
5190 
5191   // Clone all of the nodes in the function.
5192   for (auto &N : getNodes()) {
5193     Node *copy = N.clone();
5194     // Record the copy relationship between the graphs.
5195     currToNew[&N] = copy;
5196     newF->addNode(copy);
5197     if (N.hasPredicate()) {
5198       copy->setPredicate(N.getPredicate());
5199     }
5200   }
5201 
5202   // At this point we have a new invalid function that points into nodes in
5203   // the original function. Here we update the links between the nodes in the
5204   // new function.
5205   for (auto &N : newF->getNodes()) {
5206     // Fix each one of the inputs of this node.
5207     for (unsigned inp = 0, e = N.getNumInputs(); inp < e; inp++) {
5208       auto input = N.getNthInput(inp);
5209 
5210       auto it = currToNew.find(input.getNode());
5211       if (it == currToNew.end()) {
5212         assert(isa<Storage>(input.getNode()) &&
5213                "Could not find a mapping for some node!");
5214         continue;
5215       }
5216 
5217       // Update the node with the edge to the current graph.
5218       N.setNthInput(inp, NodeValue(it->second, input.getResNo()));
5219     }
5220 
5221     if (N.hasPredicate()) {
5222       auto it = currToNew.find(N.getPredicate().getNode());
5223       if (it != currToNew.end()) {
5224         N.setPredicate(NodeValue(it->second, N.getPredicate().getResNo()));
5225       }
5226     }
5227   }
5228 
5229   // Record the node mapping into the external map.
5230   if (map) {
5231     assert(map->empty() && "The external map must be empty");
5232     for (auto it : currToNew) {
5233       map->insert(it);
5234     }
5235   }
5236 
5237   assert(newF->getNodes().size() == getNodes().size() && "Invalid func size");
5238   return newF;
5239 }
5240 
5241 /// Verify the input \p idx of a node \p N. Check that the node \p N is in the
5242 /// use-list of the corresponding input node.
5243 static bool verifyNodeInput(const Node &N, size_t idx) {
5244   auto input = N.getNthInput(idx);
5245   auto *refN = input.getNode();
5246   // Check that N is in the use-list of the input node and there is a proper
5247   // entry for it.
5248   for (auto &U : refN->getUsers()) {
5249     if (U.getUser() == &N && *U.get() == input) {
5250       return true;
5251     }
5252   }
5253 
5254   report("Any node referencing another node N must be in the use-list of the "
5255          "node N");
5256   return false;
5257 }
5258 
5259 Module *Module::clone() const {
5260   auto *M = new Module;
5261   return clone(M);
5262 }
5263 
5264 Module *Module::clone(Module *M) const {
5265   // Maps current nodes to new nodes.
5266   llvm::DenseMap<const Node *, Node *> currToNew;
5267   // Clone placeholders.
5268   for (auto &PH : getPlaceholders()) {
5269     auto *copyPH = M->createPlaceholder(PH->getType(), PH->getName(),
5270                                         PH->isTraining(), PH->getLayout());
5271     currToNew[PH] = copyPH;
5272   }
5273   // Clone constants.
5274   for (auto &C : getConstants()) {
5275     // Cloner cannot decide on its own what to do with constants having unowned
5276     // payloads. Some kind of policy/hook may be needed in the future for
5277     // deciding what needs to be done in such cases.
5278     DCHECK(!C->getPayload().isUnowned())
5279         << "Cannot copy constant " << C->getName().str()
5280         << ": Unowned payloads are not supported";
5281     auto *copyC = M->createConstant(C->getType(), C->getName(), C->getLayout());
5282     copyC->assign(&C->getPayload());
5283     currToNew[C] = copyC;
5284   }
5285   // Clone all functions.
5286   for (auto *F : getFunctions()) {
5287     // Create an empty clone function in the new module.
5288     auto *copyF = M->createFunction(F->getName());
5289     // Clone function's body into the newly created cloned function. Use the
5290     // currToNew to properly map constants and placeholders.
5291     F->clone(copyF, nullptr, &currToNew);
5292     // Update all types by cloned types.
5293     for (auto &N : copyF->getNodes()) {
5294       for (unsigned idx = 0, e = N.getNumResults(); idx < e; ++idx) {
5295         N.setType(idx, M->uniqueType(*N.getType(idx)));
5296       }
5297     }
5298   }
5299   return M;
5300 }
5301 
5302 /// \returns True if \p n is a storage node (constant or placeholder) of the
5303 /// function \p F.
5304 static bool isGraphStorageNode(Node *n, const Function *F) {
5305   auto &vars = F->getParent()->getConstants();
5306   auto &placeholders = F->getParent()->getPlaceholders();
5307 
5308   if (Constant *V = dyn_cast<Constant>(n)) {
5309     return std::find(vars.begin(), vars.end(), V) != vars.end();
5310   }
5311 
5312   if (Placeholder *P = dyn_cast<Placeholder>(n)) {
5313     return std::find(placeholders.begin(), placeholders.end(), P) !=
5314            placeholders.end();
5315   }
5316 
5317   return false;
5318 }
5319 
5320 /// Insert \p node in \p nameToNode and report an error if the insertion fails.
5321 /// \returns True if \p node was inserted into \p nameToNode. False otherwise.
5322 /// A return value of true means that \p nameToNode had no other node
5323 /// registered under \p node.getName().
5324 static bool
5325 insertAndReport(std::unordered_map<std::string, const Node *> &nameToNode,
5326                 const Node &node, const Function &function) {
5327   bool inserted = expectCompareTrue(
5328       "Node is not unique", nameToNode.insert({node.getName(), &node}).second,
5329       true, &function);
5330   if (!inserted) {
5331     std::string storage;
5332     llvm::raw_string_ostream msg(storage);
5333     /// Output extra information helping to find the error.
5334     msg << "The node with name '" << node.getName()
5335         << "' conflicts with a previous definition:\n";
5336     msg << "Current definition: " << node.getDebugDesc() << "\n";
5337     msg << "Previous definition: "
5338         << nameToNode[node.getName()]->getDebugDesc();
5339     report(msg.str().c_str());
5340     return false;
5341   }
5342   return true;
5343 }
5344 
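/// Check the structural invariants of the function: tensor layouts (against
/// the given \p backend, or Glow's canonical layout otherwise), unique node
/// names, consistent use-lists, result types owned by the parent module, and
/// placeholders that are written at most once without implicit
/// read-after-write hazards. \returns true if the function is valid.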
5345 bool Function::verify(const Backend *backend) const {
5346   bool isValid = true;
5347   if (backend) {
5348     if (backend->getTensorLayoutRequirements().isEnabled()) {
5349       isValid &= expectCompareTrue(
5350           "Expected correct backend-specific layouts for the graph",
5351           verifyLayouts(*this, backend->getTensorLayoutRequirements()), true,
5352           this);
5353     }
5354   } else {
5355     // Always run verification pre-lowering / when we don't have backend:
5356     isValid &= expectCompareTrue(
5357         "Expected correct Glow canonical layouts for the graph",
5358         verifyLayouts(*this, CanonicalTensorLayout::getInstance()), true, this);
5359   }
5360   std::unordered_map<std::string, const Node *> nameToNode;
5361 
5362   for (auto *V : findConstants()) {
5363     isValid &= insertAndReport(nameToNode, *V, *this);
5364     isValid &= expectCompareTrue("Constant and its payload must have same type",
5365                                  *V->getType(), V->getPayload().getType(), V);
5366   }
5367 
5368   nameToNode.clear();
5369   for (const auto &N : nodes_) {
5370     isValid &= insertAndReport(nameToNode, N, *this);
5371   }
5372 
5373   // Any node referenced by one of the graph nodes should be part of the
5374   // Graph.
5375   for (const auto &N : nodes_) {
5376     for (size_t idx = 0, e = N.getNumInputs(); idx < e; ++idx) {
5377       auto &input = N.getNthInput(idx);
5378       // Verify each input of N.
5379       isValid &= verifyNodeInput(N, idx);
5380       bool foundNode =
5381           std::find(nodes_.begin(), nodes_.end(), *input) != nodes_.end();
5382       isValid &= expectCompareTrue(
5383           "Every node referenced by one of the graph nodes should be part of "
5384           "the graph",
5385           foundNode || isGraphStorageNode(input, this), true, &N);
5386     }
5387   }
5388 
5389   // Check that all uses of each node refer to this node.
5390   for (const auto &N : nodes_) {
5391     for (const auto &U : N.getUsers()) {
5392       isValid &= expectCompareTrue<const Node *>(
5393           "All uses of a node should refer to this node", U.get()->getNode(),
5394           &N, &N);
5395
5396     }
5397   }
5398 
5399   // Check that all types used by nodes belong to the parent module.
5400   auto &types = getParent()->getTypes();
5401   for (const auto &N : nodes_) {
5402     for (size_t idx = 0, e = N.getNumResults(); idx < e; ++idx) {
5403       auto ty = N.getType(idx);
5404       bool foundType =
5405           std::find(types.begin(), types.end(), *ty) != types.end();
5406       isValid &= expectCompareTrue(
5407           "Every type used by one of the graph nodes should be part of "
5408           "the graph",
5409           foundType, true, &N);
5410     }
5411   }
5412 
5413   std::unordered_map<const Placeholder *, const Node *> placeholderWrittenTo;
5414   for (const auto &N : nodes_) {
5415     isValid &=
5416         expectCompareTrue("Node is not linked to the function it belongs",
5417                           N.getParent(), this, &N);
5418     isValid &= N.verify();
5419     // Make sure all the placeholders are written at most once, and that
5420     // constants are never written to.
5421     for (size_t idx = 0, e = N.getNumInputs(); idx < e; ++idx) {
5422       // Placeholders and Constants have no input, so they can only be
5423       // written to via overwritten inputs.
5424       if (!N.isOverwrittenNthInput(idx)) {
5425         continue;
5426       }
5427 
5428       const Node *nthInputNode = N.getNthInput(idx).getNode();
5429       isValid &= expectCompareTrue(
5430           "Constants can never be used as an overwritten input",
5431           isa<Constant>(nthInputNode), false, nthInputNode);
5432 
5433       // Unlike Constants, Placeholders can be used at most once as
5434       // overwritten inputs. Keep a map of Placeholders to Nodes that used
5435       // them as overwritten inputs, which is also used later to check for
5436       // read-after-write dependence violations.
5437       const auto *ph = dyn_cast<Placeholder>(nthInputNode);
5438       if (!ph) {
5439         continue;
5440       }
5441       auto varToFirstDef = placeholderWrittenTo.find(ph);
5442       bool writtenOnce = expectCompareTrue(
5443           "Placeholder has more than one write",
5444           varToFirstDef == placeholderWrittenTo.end(), true, ph);
5445       if (!writtenOnce) {
5446         isValid = false;
5447         std::string storage;
5448         llvm::raw_string_ostream msg(storage);
5449 
5450         msg << "Placeholder " << ph->getDebugDesc() << '\n';
5451         msg << "has more than one write; second writer found:\n";
5452         msg << N.getDebugDesc() << '\n';
5453         msg << varToFirstDef->second->getDebugDesc() << '\n';
5454 
5455         report(msg.str().c_str());
5456       }
5457 
5458       placeholderWrittenTo[ph] = &N;
5459     }
5460   }
5461 
5462   // Now check that the placeholders that are written to are either:
5463   // - Written by a save node, or
5464   // - Are only used by the node that writes them
5465   // If this check fails, that means we have implicit memory
5466   // dependencies that may not be honored by the scheduler.
5467   // Either the input IR is incorrect or the scheduler needs
5468   // fixing.
5469   for (const std::pair<const Placeholder *, const Node *> &varToWrite :
5470        placeholderWrittenTo) {
5471     if (isa<SaveNode>(varToWrite.second)) {
5472       continue;
5473     }
5474     for (const NodeUse &use : varToWrite.first->getUsers()) {
5475       const Node *user = use.getUser();
5476       // Ignore users outside this function.
5477       if (user->getParent() != this) {
5478         continue;
5479       }
5480       isValid &= expectCompareTrue(
5481           "Implicit read after write memory dependency may not be honored",
5482           user, varToWrite.second, user);
5483     }
5484   }
5485   return isValid;
5486 }
5487 
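/// \returns the SaveNode inside \p F that writes to placeholder \p PH, or
/// nullptr if \p PH is not saved by \p F (or belongs to a different function).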
5488 SaveNode *glow::getOutputSave(Function *F, Placeholder *PH) {
5489   // If PH has a parent set, check that it matches the provided Function.
5490   auto *PHP = PH->getParent();
5491   if (PHP != nullptr && F != PHP) {
5492     return nullptr;
5493   }
5494   for (auto &use : PH->getUsers()) {
5495     if (auto *save = llvm::dyn_cast<SaveNode>(use.getUser())) {
5496       if (save->getParent() == F && save->getPlaceholder() == PH) {
5497         return save;
5498       }
5499     }
5500   }
5501   return nullptr;
5502 }
5503 
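/// Clone \p node into \p newF together with all of its non-Storage inputs,
/// reusing clones already recorded in \p currToNew (which is updated as new
/// clones are created). \returns the clone of \p node.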
5504 Node *glow::recursiveClone(Function *newF, Node *node, NodeMap &currToNew) {
5505   Node *copy = node->clone();
5506   currToNew[node] = copy;
5507   newF->addNode(copy);
5508   for (unsigned inp = 0, e = copy->getNumInputs(); inp < e; inp++) {
5509     auto input = copy->getNthInput(inp);
5510     auto it = currToNew.find(input.getNode());
5511     Node *newInput;
5512     if (it != currToNew.end()) {
5513       newInput = it->second;
5514     } else if (llvm::isa<Storage>(input.getNode())) {
5515       continue;
5516     } else {
5517       newInput = recursiveClone(newF, input.getNode(), currToNew);
5518     }
5519     copy->setNthInput(inp, NodeValue(newInput, input.getResNo()));
5520   }
5521   return copy;
5522 }
5523 
5524 namespace glow {
5525 /// If \p PH is an output placeholder, \returns true.
5526 /// This is determined by checking if the PH has a user which uses the PH as an
5527 /// overwritten input.
5528 bool isOutput(const Placeholder *PH, const Function &F) {
5529   for (const auto &use : PH->getUsers()) {
5530     // Look through the inputs of the PH's users. If an input is overwritten
5531     // check if it's the PH, if it is return true.
5532     auto *user = use.getUser();
5533     // Consider only users inside the same function.
5534     if (user->getParent() != &F) {
5535       continue;
5536     }
5537     for (unsigned i = 0, numInputs = user->getNumInputs(); i < numInputs; i++) {
5538       // If the input is not overwritten we can continue.
5539       if (!user->isOverwrittenNthInput(i)) {
5540         continue;
5541       }
5542       auto input = use.getUser()->getNthInput(i);
5543       if (input.getNode() == PH) {
5544         return true;
5545       }
5546     }
5547   }
5548   return false;
5549 }
5550 
5551 /// If \p PH is an input placeholder, \returns true.
5552 bool isInput(const Placeholder *PH, const Function &F) {
5553   // Check that the PH is the input to a SaveNode or is used by a non-SaveNode.
5554   for (const auto &use : PH->getUsers()) {
5555     // Consider only users inside the same function.
5556     if (use.getUser()->getParent() != &F) {
5557       continue;
5558     }
5559     // Check if PH is an input to a saveNode.
5560     if (auto *save = dyn_cast<SaveNode>(use.getUser())) {
5561       auto input = save->getInput();
5562       // If the PH is not an input to the saveNode we keep looking.
5563       if (input.getNode() != PH) {
5564         continue;
5565       }
5566     }
5567     return true;
5568   }
5569   return false;
5570 }
5571 
5572 llvm::raw_ostream &operator<<(llvm::raw_ostream &os, const Module &mod) {
5573   mod.dump(os);
5574   return os;
5575 }
5576 
5577 llvm::raw_ostream &operator<<(llvm::raw_ostream &os, const Module *mod) {
5578   assert(mod != nullptr && "Null Pointer.");
5579   mod->dump(os);
5580   return os;
5581 }
5582 
5583 llvm::raw_ostream &operator<<(llvm::raw_ostream &os, const Function &F) {
5584   F.dump(os);
5585   return os;
5586 }
5587 
5588 llvm::raw_ostream &operator<<(llvm::raw_ostream &os, const Function *F) {
5589   assert(F != nullptr && "Null Pointer.");
5590   F->dump(os);
5591   return os;
5592 }
5593 
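/// \returns true if \p node is a 2D NHWC convolution that is equivalent to a
/// FullyConnected node: 1x1 filters, unit strides, no padding, a single
/// group, unit dilation and no fused activation. When \p enforceInput1x1 is
/// set, the input spatial dimensions must also be 1x1.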
5594 bool isConvolutionSameAsFullyConnected(const ConvolutionNode *node,
5595                                        bool enforceInput1x1) {
5596   bool isConv2D = (node->getInput().getType()->dims().size() == 4);
5597   if (!(isConv2D && node->getLayout() == ConvolutionLayout::NHWC &&
5598         !node->hasFusedActivation())) {
5599     return false;
5600   }
5601   auto filterDims = ShapeNHWC(node->getFilter().getType()->dims());
5602   ShapeHW kernels = ShapeHW(node->getKernels());
5603   ShapeHW strides = ShapeHW(node->getStrides());
5604   PaddingTLBR pads = PaddingTLBR(node->getPads());
5605   auto group = node->getGroup();
5606   auto dilation = node->getDilation();
5607 
5608   bool isSame = (filterDims.h == 1) && (filterDims.w == 1);
5609   isSame &= (kernels.height == 1) && (kernels.width == 1);
5610   isSame &= (strides.height == 1) && (strides.width == 1);
5611   isSame &= (pads.top == 0) && (pads.left == 0) && (pads.bottom == 0) &&
5612             (pads.right == 0);
5613   isSame &= (group == 1);
5614   isSame &= (dilation == 1);
5615   if (enforceInput1x1) {
5616     auto inputDims = ShapeNHWC(node->getInput().getType()->dims());
5617     isSame &= (inputDims.h == 1) && (inputDims.w == 1);
5618   }
5619   return isSame;
5620 }
5621 
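/// \returns true if \p node is a Gemm node that maps directly onto a
/// FullyConnected node: alpha == 1, beta == 1 and a 1D bias operand C.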
5622 bool isGemmSameAsFullyConnected(const GemmNode *node) {
5623   NodeValue inpC = node->getC();
5624   return (node->getAlpha() == 1.0) && (node->getBeta() == 1.0) &&
5625          (inpC.getNode()) && (inpC.dims().size() == 1);
5626 }
5627 
5628 } // namespace glow
5629