/**
 * Copyright (c) Glow Contributors. See CONTRIBUTORS file.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "glow/Graph/Graph.h"
#include "glow/Backend/Backend.h"
#include "glow/Graph/Nodes.h"
#include "glow/Graph/PlaceholderBindings.h"
#include "glow/Graph/TensorLayout.h"
#include "glow/Graph/VerifierHelper.h"
#include "glow/Quantization/Base/Base.h"
#include "glow/Support/Support.h"

#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/SmallString.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/FileSystem.h"
#include "llvm/Support/Format.h"
#include "llvm/Support/raw_ostream.h"

#ifdef WIN32
#include <corecrt_math_defines.h>
#endif
#include <float.h>
#include <fstream>
#include <unordered_set>

using namespace glow;
using llvm::cast;
using llvm::dyn_cast;
using llvm::isa;

namespace {
/// A helper function to log the deletion of constant/placeholder \p s of a
/// module into the log context of the given \p functions.
/// Note: we don't log the deletion of a constant/placeholder in the function
/// that uses or creates it, because constants/placeholders have no parent
/// function (and we can't rely on a user's function either, since the users
/// might already have been removed). It is therefore best to log them in a
/// Module-level log context and copy the entry over to all functions.
void logStorageDeletion(std::list<Function *> functions, Storage *s) {
  for (auto *F : functions) {
    F->getLogContext()->logNodeDeletion(*s);
  }
  if (functions.size() > 0) {
    auto *F = *(functions.begin());
    F->getLogContext()->logNodeDeletion(*s, /* logIntoModule */ true);
  }
}

/// A helper function to log the creation of constant/placeholder \p s of a
/// module into the log context of the given \p functions.
/// Same note as for logStorageDeletion().
void logStorageCreation(std::list<Function *> functions, Storage *s) {
  for (auto *F : functions) {
    F->getLogContext()->logNodeCreation(*s);
  }
  if (functions.size() > 0) {
    auto *F = *(functions.begin());
    F->getLogContext()->logNodeCreation(*s, /* logIntoModule */ true);
  }
}
} // namespace

/// Merge shape \p shape into \p mergeShape, following multidirectional
/// broadcasting rules.
static void mergeMultidirectionalBroadcastHelper(std::vector<dim_t> &mergeShape,
                                                 llvm::ArrayRef<dim_t> shape) {
  size_t shift = mergeShape.size() - shape.size();
  for (size_t i = 0, e = shape.size(); i < e; i++) {
    if (shape[i] == 1) {
      // Just leave mergeShape[shift + i] as it is.
      continue;
    }

    assert(
        ((shape[i] == mergeShape[shift + i]) || (mergeShape[shift + i] == 1)) &&
        "Incompatible dimension for the broadcast");
    mergeShape[shift + i] = shape[i];
  }
}

/// Utility function which computes the resulting shape in case of
/// multidirectional broadcasting.
static std::vector<dim_t>
computeMultidirectionalBroadcastHelper(llvm::ArrayRef<dim_t> shape0,
                                       llvm::ArrayRef<dim_t> shape1) {
  size_t numDims0 = shape0.size();
  size_t numDims1 = shape1.size();
  size_t newNumDims = std::max(numDims0, numDims1);
  std::vector<dim_t> reshapeDims(newNumDims, 1);

  mergeMultidirectionalBroadcastHelper(reshapeDims, shape0);
  mergeMultidirectionalBroadcastHelper(reshapeDims, shape1);

  return reshapeDims;
}
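
// Worked example (illustrative note, not in the original source): shapes are
// right-aligned and size-1 dimensions stretch, so merging {3, 4, 1} with
// {4, 5} right-aligns {4, 5} against {3, 4, 1} and stretches the trailing 1:
//   computeMultidirectionalBroadcastHelper({3, 4, 1}, {4, 5}) == {3, 4, 5}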

std::vector<NodeValue>
Function::broadcastInputs(int axis, const llvm::ArrayRef<NodeValue> inputs) {
  dim_t numInputs = inputs.size();

  if (axis > -1) {
    assert(numInputs == 2 &&
           "If axis is specified (not -1), unidirectional broadcast is used "
           "and the input size must be 2.");
    return {inputs[0],
            createBroadcast("broadcast_" + inputs[1].getNode()->getName().str(),
                            inputs[1], inputs[0].dims(), axis)};
  }

  assert(numInputs >= 2 && "Invalid input passed in to broadcastInputs.");

  std::vector<dim_t> targetDim = computeMultidirectionalBroadcastHelper(
      inputs[0].dims(), inputs[1].dims());

  for (size_t i = 2; i < numInputs; ++i) {
    targetDim =
        computeMultidirectionalBroadcastHelper(targetDim, inputs[i].dims());
  }

  std::vector<NodeValue> out(numInputs);
  for (size_t i = 0; i < numInputs; ++i) {
    NodeValue n = inputs[i];
    auto dims = n.dims();
    if (dims != llvm::ArrayRef<dim_t>(targetDim)) {
      unsigned axis = targetDim.size() - dims.size();
      out[i] = createBroadcast("broadcast_" + n.getNode()->getName().str(), n,
                               targetDim, axis);
    } else {
      out[i] = inputs[i];
    }
  }
  return out;
}
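
// Usage sketch (illustrative): with axis == -1, inputs of shapes {5, 1} and
// {1, 3} are both broadcast to the merged target shape {5, 3}. With a
// non-negative axis, exactly two inputs are accepted and only inputs[1] is
// broadcast (unidirectionally) to inputs[0].dims().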

bool Module::hasFunction(llvm::StringRef name) { return getFunction(name); }

void Module::clearFunctions() {
  for (auto *F : functions_) {
    F->clear();
  }
}

void Function::clear() {
  nodes_.clear();
  uniqueNodeNames_.clear();
}

Function *Module::getFunction(llvm::StringRef name) {
  for (auto *F : functions_) {
    if (F->getName() == name) {
      return F;
    }
  }
  return nullptr;
}

Function *Module::createFunction(llvm::StringRef name) {
  assert(!hasFunction(name) && "A function with this name already exists");
  Function *F = new Function(this, name);
  functions_.push_back(F);
  return F;
}

void Module::strip() {
  for (auto it = constants_.begin(), e = constants_.end(); it != e; it++) {
    Constant *v = *it;
    v->clearPayload();
  }
}

void Module::clear() {
  for (auto it = constants_.begin(), e = constants_.end(); it != e; it++) {
    Constant *v = *it;
    logStorageDeletion(functions_, v);
    delete v;
  }

  constants_.clear();

  for (auto it = placeholders_.begin(), e = placeholders_.end(); it != e;
       it++) {
    Placeholder *p = *it;
    logStorageDeletion(functions_, p);
    delete p;
  }

  eraseFunctions();

  placeholders_.clear();
}
Module::~Module() { clear(); }

bool Module::verify() const {
  bool isValid = true;
  for (auto *F : functions_) {
    isValid &= F->verify();
  }
  // Check that all types used by constants or placeholders belong to the
  // module.
  auto &types = getTypes();
  for (const auto *PH : getPlaceholders()) {
    bool foundType =
        std::find(types.begin(), types.end(), *PH->getType()) != types.end();
    isValid &=
        expectCompareTrue("Every type used by placeholders should be part of "
                          "the graph",
                          foundType, true, PH);
  }
  for (const auto *C : getConstants()) {
    bool foundType =
        std::find(types.begin(), types.end(), *C->getType()) != types.end();
    isValid &=
        expectCompareTrue("Every type used by constants should be part of "
                          "the graph",
                          foundType, true, C);
  }
  return isValid;
}

void Module::dump() const {
  llvm::outs() << "Module structure:\n";
  for (auto *C : getConstants()) {
    llvm::outs() << C->getDebugDesc() << "\n";
  }

  for (auto *P : getPlaceholders()) {
    llvm::outs() << P->getDebugDesc() << "\n";
  }

  for (auto *F : functions_) {
    llvm::outs() << "Function:" << F->getName() << "\n";
  }
}

std::string Module::toString() const {
  std::string storage;
  llvm::raw_string_ostream os(storage);
  dump(os);
  return os.str();
}

/// Creates a std::set copy of \p unsorted, sorted based on the name of each
/// element, and \returns it.
template <class T>
static std::set<T *, SortNamed> getNamedSorted(const std::list<T *> &unsorted) {
  return std::set<T *, SortNamed>(unsorted.begin(), unsorted.end());
}

void Module::dump(llvm::raw_ostream &os) const {
  os << "Module structure:\n";
  for (auto *C : getNamedSorted(constants_)) {
    os << C->getDebugDesc() << "\n";
  }
  for (auto *P : getNamedSorted(placeholders_)) {
    os << P->getDebugDesc() << "\n";
  }
  for (auto *F : getNamedSorted(functions_)) {
    os << "Function : " << F->getName() << "\n";
  }
}

/// A helper class for visiting and generating the dotty graph file.
class AbstractDottyPrinter {
protected:
  // List of generated vertices.
  std::vector<std::string> vertices_{};
  // List of generated edges.
  std::unordered_set<std::string> edges_{};
  // Map node addresses to unique numbers.
  using VertexNumberMap = std::unordered_map<void *, unsigned>;
  VertexNumberMap vertex_numbers{};

  /// Dumps the label for an input/output row, given port names.
  /// E.g. {"LHS", "RHS"} will produce {<LHS>LHS|<RHS>RHS}
  void dumpLabelForRow(llvm::ArrayRef<std::string> names, std::ostream &os) {
    os << "{";
    for (size_t i = 0; i < names.size(); i++) {
      if (i) {
        os << "|";
      }
      os << "<" << names[i] << ">" << names[i];
    }
    os << "}";
  }

  void dumpLabel(Node *N, std::ostream &os) {
    os << "{";
    if (N->getNumInputs()) {
      std::vector<std::string> names(N->getNumInputs());
      for (size_t i = 0; i < names.size(); i++) {
        names[i] = N->getInputName(i);
      }
      dumpLabelForRow(names, os);
      os << "|";
    }
    os << "{" << escapeDottyString(N->getDebugDesc()) << "}";
    if (N->getNumResults()) {
      os << "|";
      std::vector<std::string> names(N->getNumResults());
      for (size_t i = 0; i < names.size(); i++) {
        names[i] = N->getOutputName(i).str();
      }
      dumpLabelForRow(names, os);
    }
    os << "}";
  }

  void dumpNode(Node *N, bool uniqueNames) {
    if (!N) {
      return;
    }
    std::ostringstream os;
    // Print a node descriptor that looks like this:
    if (uniqueNames) {
      // vNNNN [ shape = "record" label = "{...}" ];
      os << uniqueVertexName(N) << "[\n";
    } else {
      // <name> [ shape = "record" label = "{...}" ];
      os << N->getName().str() << "[\n";
    }
    os << "\tlabel = \"";
    dumpLabel(N, os);
    os << "\"\n";
    os << "\tshape = \"record\"\n";
    os << "\tstyle=\"filled,rounded\"\n";

    // Pick a color based on the node kind.
    unsigned colorIdx = llvm::hash_value(llvm::StringRef(N->getKindName()));
    auto nodeColor = getDotFileNodeColor(colorIdx);

    if (isa<Constant>(N)) {
      os << "\tfillcolor=Snow3 color=DeepSkyBlue4\n";
    } else {
      os << "\tfillcolor=" << nodeColor << "\n";
    }
    os << "penwidth = 2];\n";

    vertices_.push_back(os.str());
  }

  void dumpEdgeStyle(const Node *N, size_t i, Node *to, std::ostream &os) {
    if (N->isOverwrittenNthInput(i)) {
      os << " [dir=\"both\"]";
    }
  }

  std::string uniqueVertexName(void *N) {
    VertexNumberMap::iterator i;
    bool inserted;
    std::tie(i, inserted) = vertex_numbers.insert(std::make_pair(N, 0u));
    if (inserted) {
      i->second = vertex_numbers.size() - 1;
    }

    std::string buffer;
    llvm::raw_string_ostream stream(buffer);
    stream << llvm::format("v%04u", i->second);
    return stream.str();
  }

public:
  void dumpAll(std::ostream &os) {
    CHECK(os) << "Failed to create file to dump Graph";

    os << "digraph DAG {\n\trankdir=TB;\n";

    // Dump vertices:
    for (auto &v : vertices_) {
      os << v << "\n";
    }

    // Dump edges:
    for (auto &e : edges_) {
      os << e << ";\n";
    }

    os << "}";
  }
};

class ModuleDottyPrinter : public AbstractDottyPrinter {
  /// Dump Function as a vertex. Then iterate through the constants used in
  /// the function and create the corresponding edges.
  void visitFunction(Function *F) {
    std::ostringstream os;
    // Print a Function descriptor that looks like this:
    // vNNNN [ label = "{...}" ];
    os << uniqueVertexName(F) << "[\n"
       << "\tlabel = \"Function\\l"
       << "name : " << F->getName().str() << "\\l"
       << "node count : " << F->getNodes().size() << "\"\n"
       << "\tshape = box\n"
       << "\tfillcolor=gray89, style=\"filled,rounded\"\n"
       << "\t\n"
       << "];\n";
    vertices_.push_back(os.str());

    for (auto &N : F->getNodes()) {
      for (size_t i = 0; i < N.getNumInputs(); i++) {
        Node *to = N.getNthInput(i).getNode();
        size_t resNo = N.getNthInput(i).getResNo();

        if (!isa<Constant>(to))
          continue;

        std::ostringstream edge;
        edge << uniqueVertexName(to) << ":" << to->getOutputName(resNo).str()
             << " -> " << uniqueVertexName(F);
        dumpEdgeStyle(&N, i, to, edge);
        edges_.insert(edge.str());
      }
    }
  }

public:
  void visitModule(Module *M) {
    for (auto N : M->getConstants()) {
      dumpNode(N, true);
    }

    for (auto F : M->getFunctions()) {
      visitFunction(F);
    }
  }
};

// TODO: consider refactoring boilerplate code to new trait: DottyPrintable<ADP>
void Module::dumpDAG() {
  llvm::SmallString<64> dotPath;
  llvm::sys::fs::createTemporaryFile("dotty_graph_dump", "dot", dotPath);
  dumpDAG(dotPath);
}

void Module::dumpDAG(llvm::StringRef dotFilename) {
  llvm::outs() << "Writing dotty graph for Module to: " << dotFilename << '\n';

  ModuleDottyPrinter DP;

  DP.visitModule(this);

  std::ofstream myfile;
  myfile.open(dotFilename);
  DP.dumpAll(myfile);
  myfile.close();
}

void Module::dumpDAG(const char *dotFilename) {
  dumpDAG(llvm::StringRef(dotFilename));
}

void Module::eraseFunctions() {
  while (!functions_.empty()) {
    eraseFunction(*functions_.begin());
  }
}

void Module::eraseFunction(Function *F) {
  auto it = std::find(functions_.begin(), functions_.end(), F);
  assert(it != functions_.end() && "Function is not part of a module");
  functions_.erase(it);
  delete F;
}

uint64_t Module::getConstantsSize() {
  uint64_t size = 0;
  for (auto *constant : constants_) {
    size += constant->getPayload().getSizeInBytes();
  }
  return size;
}

Function::~Function() {
  // Delete all of the nodes.
  for (auto it = nodes_.begin(), e = nodes_.end(); it != e;) {
    auto cur = it++;
    eraseNode(&*cur);
  }
}

TypeRef Module::uniqueType(ElemKind elemTy, llvm::ArrayRef<dim_t> dims) {
  return uniqueType(Type(elemTy, dims));
}

TypeRef Module::uniqueType(ElemKind elemTy, llvm::ArrayRef<dim_t> dims,
                           float scale, int32_t offset) {
  return uniqueType(Type(elemTy, dims, scale, offset));
}

TypeRef Module::uniqueTypeWithNewShape(TypeRef T, llvm::ArrayRef<dim_t> dims) {
  return uniqueType(Type::newShape(*T, dims));
}

TypeRef Module::uniqueTypeWithNewShape(TypeRef T, llvm::ArrayRef<dim_t> dims,
                                       llvm::ArrayRef<dim_t> alignments) {
  return uniqueType(Type::newShape(*T, dims, alignments));
}

TypeRef Module::uniqueTypeWithNewShape(TypeRef T, TypeRef shapeType) {
  return uniqueType(Type::newShape(*T, shapeType));
}

TypeRef Module::uniqueType(const Type &T) {
  for (auto &tp : types_) {
    if (T.isEqual(tp)) {
      return &tp;
    }
  }

  return &*types_.insert(types_.begin(), T);
}

TypeRef Module::getVoidTy() { return uniqueType(Type()); }

/// \returns a ShapeVector of rank axes.size() less than the input \p dims,
/// where the provided \p axes dimensions are removed from the shape.
static ShapeVector getNewShapeWithoutAxes(llvm::ArrayRef<dim_t> dims,
                                          llvm::ArrayRef<unsigned_t> axes) {
  assert(axes.size() <= dims.size() &&
         "Cannot remove more dimensions than exist.");
  ShapeVector newDims(dims.begin(), dims.end());
  ShapeVector shapeAxes(axes.begin(), axes.end());

  // Sort so that looping erase below doesn't fail.
  std::sort(shapeAxes.rbegin(), shapeAxes.rend());

  for (const auto &axis : shapeAxes) {
    assert(axis <= dims.size() &&
           "Axis to remove must fit inside dimensions of the provided dims.");
    newDims.erase(newDims.begin() + axis);
  }
  return newDims;
}
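
// Example (illustrative): getNewShapeWithoutAxes({2, 3, 4, 5}, {1, 3}) erases
// dimension 3 first and then dimension 1 (back-to-front, thanks to the
// reverse sort above), returning {2, 4}.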

//===----------------------------------------------------------------------===//
// Node builders
//===----------------------------------------------------------------------===//

Placeholder *Module::createPlaceholder(TypeRef T, llvm::StringRef name,
                                       bool isTrainable,
                                       const std::string &layout) {
  auto FT = uniqueType(*T);
  auto *ph = new Placeholder(name, FT, isTrainable, layout);
  ph->setName(uniqueName(ph->getName(), usedNodeNames_, usedStorageNames_,
                         originalNames_));
  placeholders_.push_back(ph);
  logStorageCreation(functions_, ph);
  return ph;
}

Placeholder *Module::createPlaceholder(ElemKind T, llvm::ArrayRef<dim_t> dims,
                                       llvm::StringRef name, bool isTrainable,
                                       const std::string &layout) {
  auto FT = uniqueType(T, dims);
  return createPlaceholder(FT, name, isTrainable, layout);
}

Placeholder *Module::createPlaceholder(ElemKind T, llvm::ArrayRef<dim_t> dims,
                                       float scale, int32_t offset,
                                       llvm::StringRef name, bool isTrainable,
                                       const std::string &layout) {
  auto FT = uniqueType(T, dims, scale, offset);
  return createPlaceholder(FT, name, isTrainable, layout);
}

Constant *Module::createConstant(TypeRef T, llvm::StringRef name,
                                 const std::string &layout) {
  auto FT = uniqueType(*T);
  return addConstant(new Constant(name, FT, layout));
}

Constant *Module::createConstant(ElemKind T, llvm::ArrayRef<dim_t> dims,
                                 llvm::StringRef name,
                                 const std::string &layout) {
  auto FT = uniqueType(T, dims);
  return createConstant(FT, name, layout);
}

Constant *Module::createConstant(ElemKind T, llvm::ArrayRef<dim_t> dims,
                                 float scale, int32_t offset,
                                 llvm::StringRef name,
                                 const std::string &layout) {
  auto FT = uniqueType(T, dims, scale, offset);
  return createConstant(FT, name, layout);
}

Constant *Module::createConstant(llvm::StringRef name, const Tensor &tensor,
                                 const std::string &layout) {
  auto *V = createConstant(&tensor.getType(), name, layout);
  V->assign(&tensor);
  return V;
}

Constant *Module::createConstant(llvm::StringRef name, Tensor &&tensor,
                                 const std::string &layout) {
  return addConstant(new Constant(name, std::move(tensor), layout));
}

std::string Module::getPrefix(llvm::StringRef name) {
  std::string prefix = name;
  size_t delim = name.rfind("__");
  if (delim != std::string::npos &&
      std::all_of(name.begin() + (delim + 2), name.end(),
                  [](unsigned char c) { return ::isdigit(c); })) {
    prefix = prefix.substr(0, delim);
  }
  return prefix;
}
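
// Example (illustrative): getPrefix("relu__2") returns "relu" because the
// name ends in "__" followed only by digits; getPrefix("relu_2") and
// getPrefix("relu__2b") return the input unchanged.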

llvm::StringRef Module::uniqueName(llvm::StringRef name,
                                   const llvm::StringSet<> &stringTable,
                                   llvm::StringSet<> &updateTable,
                                   const llvm::StringSet<> &originalNames) {
  std::string legalName = legalizeName(name);
  if (stringTable.find(legalName) == stringTable.end()) {
    auto it = updateTable.insert(legalName);
    if (it.second) {
      return it.first->first();
    }
  }
  // Retain the trailing "__[0-9]+" if it is in the original name.
  std::string prefix = (originalNames.find(legalName) == originalNames.end())
                           ? Module::getPrefix(legalName)
                           : legalName;
  for (unsigned i = 1; i < 10000; i++) {
    auto suffix = std::to_string(i);
    std::string fullName = prefix + "__" + suffix;
    if (stringTable.find(fullName) != stringTable.end()) {
      continue;
    }

    auto it = updateTable.insert(fullName);
    if (it.second) {
      return it.first->first();
    }
  }
  llvm_unreachable("Unable to find a unique name.");
}
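
// Behavior sketch (illustrative): if "conv" is taken, the loop above probes
// "conv__1", "conv__2", ... until a name is absent from both string tables.
// A name listed in \p originalNames keeps its numeric suffix as part of the
// prefix, so a taken "conv__7" probes "conv__7__1", "conv__7__2", ...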

Constant *Module::addConstant(Constant *V) {
  V->setName(uniqueName(V->getName(), usedNodeNames_, usedStorageNames_,
                        originalNames_));
  // Replace the Constant's output type with the equivalent unique type for
  // this Module to maintain the invariant that each type in the Module is
  // unique.
  V->setType(Constant::ResultIndices::OutputIdx, uniqueType(*V->getType()));
  constants_.push_back(V);
  logStorageCreation(functions_, V);
  return V;
}

/// Check if the 'pads' array has the right size.
static void assertPadsSize(NodeValue input, llvm::ArrayRef<int> pads) {
  assert((pads.size() == 2 * input.dims().size()) &&
         "the pads array must contain 2 values per dimension");
}

PadNode *Function::createPad(llvm::StringRef name, NodeValue input,
                             TypeRef outTy, unsigned_t mode,
                             llvm::ArrayRef<int> pads, float value) {
  assertPadsSize(input, pads);
  auto OT = getParent()->uniqueType(*outTy);
  return addNode(new PadNode(name, OT, input, mode, pads, value));
}

/// Check the kernel size for Conv/Pooling ops.
static void checkKernelSize(ShapeNHWC idim, llvm::ArrayRef<unsigned_t> kernels,
                            llvm::ArrayRef<unsigned_t> pads) {
  PaddingTLBR pdim(pads);
  (void)pdim;
  ShapeHW kdim(kernels);
  (void)kdim;
  assert((idim.w + pdim.left + pdim.right) >= kdim.width &&
         (idim.h + pdim.top + pdim.bottom) >= kdim.height &&
         "Kernel size is too large");
}

/// Check the kernel size for 3D Conv/Pooling ops.
static void check3DKernelSize(ShapeNTHWC idim,
                              llvm::ArrayRef<unsigned_t> kernels,
                              llvm::ArrayRef<unsigned_t> pads) {
  PaddingNFTBLR pdim(pads);
  (void)pdim;
  ShapeTHW kdim(kernels);
  (void)kdim;
  assert((idim.w + pdim.left + pdim.right) >= kdim.width &&
         (idim.h + pdim.top + pdim.bottom) >= kdim.height &&
         (idim.t + pdim.near + pdim.far) >= kdim.temporal_frames &&
         "Kernel size is too large");
}

/// Check that the dimensions that are passed in when the ConvTranspose is
/// constructed are correct.
static void assertConvTransposeDims(NodeValue input, NodeValue filter,
                                    NodeValue bias,
                                    llvm::ArrayRef<unsigned_t> kernels,
                                    llvm::ArrayRef<unsigned_t> strides,
                                    llvm::ArrayRef<unsigned_t> pads,
                                    unsigned_t group) {
  ShapeNHWC idim = ShapeNHWC(input.dims());
  (void)idim;
  ShapeHW kdim(kernels);
  (void)kdim;
  assert(idim.c % group == 0 &&
         "number of channels must be divisible by group");

  // NOTE: here the N in NHWC is abnormal because it is the number of filters
  // (and therefore the number of output channels of the conv) and not the
  // batch size. The rest of the dimensions are representative of the input
  // dimensions to the convolution.
  ShapeNHWC filterDims(filter.dims());
  (void)filterDims;

  assert(filterDims.n % group == 0 && filterDims.h == kdim.height &&
         filterDims.w == kdim.width && filterDims.c == idim.c / group &&
         "Invalid filter dims");

  assert(bias.getType()->size() == filterDims.n && "Invalid bias size");
}

/// Check that the dimensions that are passed in when the convolution is
/// constructed are correct.
static void assertConvDims(NodeValue input, NodeValue filter, NodeValue bias,
                           llvm::ArrayRef<unsigned_t> kernels,
                           llvm::ArrayRef<unsigned_t> strides,
                           llvm::ArrayRef<unsigned_t> pads, unsigned_t group) {
  ShapeNHWC idim = ShapeNHWC(input.dims());
  ShapeHW kdim(kernels);
  (void)kdim;
  checkKernelSize(idim, kernels, pads);
  assert(idim.c % group == 0 &&
         "number of channels must be divisible by group");

  // NOTE: here the N in NHWC is abnormal because it is the number of filters
  // (and therefore the number of output channels of the conv) and not the
  // batch size. The rest of the dimensions are representative of the input
  // dimensions to the convolution.
  ShapeNHWC filterDims(filter.dims());
  (void)filterDims;

  assert(filterDims.n % group == 0 && filterDims.h == kdim.height &&
         filterDims.w == kdim.width && filterDims.c == idim.c / group &&
         "Invalid filter dims");

  assert(bias.getType()->size() == filterDims.n && "Invalid bias size");
}

/// Check that the dimensions that are passed in when the 3D convolution is
/// constructed are correct.
static void assertConv3DDims(NodeValue input, NodeValue filter, NodeValue bias,
                             llvm::ArrayRef<unsigned_t> kernels,
                             llvm::ArrayRef<unsigned_t> strides,
                             llvm::ArrayRef<unsigned_t> pads,
                             unsigned_t group) {
  ShapeNTHWC idim(input.dims());
  ShapeTHW kdim(kernels);
  (void)kdim;
  check3DKernelSize(idim, kernels, pads);
  assert(idim.c % group == 0 &&
         "number of channels must be divisible by group");

  // NOTE: here the N in NTHWC is abnormal because it is the number of filters
  // (and therefore the number of output channels of the 3d conv) and not the
  // batch size. The rest of the dimensions are representative of the input
  // dimensions to the convolution.
  ShapeNTHWC filterDims(filter.dims());
  (void)filterDims;

  assert(filterDims.n % group == 0 && filterDims.h == kdim.height &&
         filterDims.w == kdim.width && filterDims.t == kdim.temporal_frames &&
         filterDims.c == idim.c / group && "Invalid filter dims");

  assert(bias.getType()->size() == filterDims.n && "Invalid bias size");
}

ConvolutionNode *Function::createConv(
    llvm::StringRef name, NodeValue input, NodeValue filter, NodeValue bias,
    TypeRef outTy, llvm::ArrayRef<unsigned_t> kernels,
    llvm::ArrayRef<unsigned_t> strides, llvm::ArrayRef<unsigned_t> pads,
    unsigned_t group, unsigned_t dilation, ConvolutionLayout layout) {
  assertConvDims(input, filter, bias, kernels, strides, pads, group);
  auto OT = getParent()->uniqueType(*outTy);

  // If the input is quantized but the bias is not then auto-quantize the
  // bias.
  if (input.getType()->isQuantizedType()) {
    auto biasType = bias.getElementType();
    if (biasType == ElemKind::Int32QTy || biasType == ElemKind::Int8QTy) {
      // Nothing to do
    } else if (biasType == ElemKind::FloatTy) {
      auto biasTy = getParent()->uniqueType(
          glow::ElemKind::Int32QTy, bias.dims(),
          input.getType()->getScale() * filter.getType()->getScale(),
          /* offset */ 0);
      bias = createQuantize("quantized_bias", bias, biasTy);
    } else {
      LOG(DFATAL)
          << "Unsupported element type for bias of quantized convolution: "
          << Type::getElementName(biasType).str();
    }
  }

  return addNode(new ConvolutionNode(name, OT, input, filter, bias, kernels,
                                     strides, pads, group, dilation, layout,
                                     FusedActivation::NONE));
}

ConvolutionNode *Function::createConv(llvm::StringRef name, NodeValue input,
                                      NodeValue filter, NodeValue bias,
                                      TypeRef outTy, unsigned_t kernel,
                                      unsigned_t stride, unsigned_t pad,
                                      unsigned_t group, unsigned_t dilation,
                                      ConvolutionLayout layout) {
  llvm::SmallVector<unsigned_t, 4> pads = {pad, pad, pad, pad};
  llvm::SmallVector<unsigned_t, 2> strides = {stride, stride};
  llvm::SmallVector<unsigned_t, 2> kernels = {kernel, kernel};
  return createConv(name, input, filter, bias, outTy, kernels, strides, pads,
                    group, dilation, layout);
}

Convolution3DNode *Function::createConv3D(llvm::StringRef name, NodeValue input,
                                          NodeValue filter, NodeValue bias,
                                          TypeRef outTy,
                                          llvm::ArrayRef<unsigned_t> kernels,
                                          llvm::ArrayRef<unsigned_t> strides,
                                          llvm::ArrayRef<unsigned_t> pads,
                                          unsigned_t group) {
  assertConv3DDims(input, filter, bias, kernels, strides, pads, group);
  auto OT = getParent()->uniqueType(*outTy);

  // If the input is quantized but the bias is not then auto-quantize the
  // bias.
  if (input.getType()->isQuantizedType()) {
    auto biasType = bias.getElementType();
    if (biasType == ElemKind::Int32QTy || biasType == ElemKind::Int8QTy ||
        biasType == ElemKind::Int16QTy) {
      // Nothing to do
    } else if (biasType == ElemKind::FloatTy) {
      auto biasTy = getParent()->uniqueType(
          glow::ElemKind::Int32QTy, bias.dims(),
          input.getType()->getScale() * filter.getType()->getScale(),
          /* offset */ 0);
      bias = createQuantize("quantized_bias", bias, biasTy);
    } else {
      LOG(DFATAL)
          << "Unsupported element type for bias of quantized convolution: "
          << Type::getElementName(biasType).str();
    }
  }
  return addNode(new Convolution3DNode(name, OT, input, filter, bias, kernels,
                                       strides, pads, group));
}

Convolution3DNode *Function::createConv3D(llvm::StringRef name, NodeValue input,
                                          NodeValue filter, NodeValue bias,
                                          TypeRef outTy, unsigned_t kernel,
                                          unsigned_t stride, unsigned_t pad,
                                          unsigned_t group) {
  llvm::SmallVector<unsigned_t, 6> pads = {pad, pad, pad, pad, pad, pad};
  llvm::SmallVector<unsigned_t, 3> strides = {stride, stride, stride};
  llvm::SmallVector<unsigned_t, 3> kernels = {kernel, kernel, kernel};
  return createConv3D(name, input, filter, bias, outTy, kernels, strides, pads,
                      group);
}

ConvTransposeNode *Function::createConvTranspose(
    llvm::StringRef name, NodeValue input, NodeValue filter, NodeValue bias,
    TypeRef outTy, llvm::ArrayRef<unsigned_t> kernels,
    llvm::ArrayRef<unsigned_t> strides, llvm::ArrayRef<unsigned_t> pads,
    unsigned_t group, unsigned_t dilation) {
  assertConvTransposeDims(input, filter, bias, kernels, strides, pads, group);
  auto OT = getParent()->uniqueType(*outTy);
  return addNode(new ConvTransposeNode(name, OT, input, filter, bias, kernels,
                                       strides, pads, group, dilation));
}

ConvTransposeNode *Function::createConvTranspose(
    llvm::StringRef name, NodeValue input, NodeValue filter, NodeValue bias,
    TypeRef outTy, unsigned_t kernel, unsigned_t stride, unsigned_t pad,
    unsigned_t group, unsigned_t dilation) {
  llvm::SmallVector<unsigned_t, 4> pads = {pad, pad, pad, pad};
  llvm::SmallVector<unsigned_t, 2> strides = {stride, stride};
  llvm::SmallVector<unsigned_t, 2> kernels = {kernel, kernel};
  return createConvTranspose(name, input, filter, bias, outTy, kernels, strides,
                             pads, group, dilation);
}

MaxPoolNode *Function::createMaxPool(llvm::StringRef name, NodeValue input,
                                     llvm::ArrayRef<unsigned_t> kernels,
                                     llvm::ArrayRef<unsigned_t> strides,
                                     llvm::ArrayRef<unsigned_t> pads,
                                     ElemKind elemTyAMT,
                                     ConvolutionLayout layout) {
  ShapeNHWC idim = ShapeNHWC(input.dims());
  checkKernelSize(idim, kernels, pads);

  auto outSz =
      calculateConvPoolOutputDims(idim.h, idim.w, kernels, strides, pads);
  auto OT = getParent()->uniqueTypeWithNewShape(
      input.getType(), {idim.n, outSz.first, outSz.second, idim.c});
  auto AMT = getParent()->uniqueType(
      elemTyAMT, {idim.n, outSz.first, outSz.second, idim.c});

  return addNode(
      new MaxPoolNode(name, OT, AMT, input, kernels, strides, pads, layout));
}
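
// For reference (illustrative, assuming unit dilation): the spatial output
// size produced by calculateConvPoolOutputDims follows the usual sliding
// window formula, e.g.
//   outH = (inH + padTop + padBottom - kernelH) / strideH + 1
// so a 224x224 input pooled with a 3x3 kernel, stride 2 and no padding
// yields a 111x111 result.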

MaxPoolNode *Function::createMaxPool(llvm::StringRef name, NodeValue input,
                                     unsigned_t kernel, unsigned_t stride,
                                     unsigned_t pad, ElemKind elemTyAMT,
                                     ConvolutionLayout layout) {
  llvm::SmallVector<unsigned_t, 4> pads = {pad, pad, pad, pad};
  llvm::SmallVector<unsigned_t, 2> strides = {stride, stride};
  llvm::SmallVector<unsigned_t, 2> kernels = {kernel, kernel};
  return createMaxPool(name, input, kernels, strides, pads, elemTyAMT, layout);
}

AvgPoolNode *Function::createAvgPool(llvm::StringRef name, NodeValue input,
                                     llvm::ArrayRef<unsigned_t> kernels,
                                     llvm::ArrayRef<unsigned_t> strides,
                                     llvm::ArrayRef<unsigned_t> pads,
                                     ConvolutionLayout layout) {
  if (!is3DData(layout)) {

    ShapeNHWC idim = ShapeNHWC(input.dims());
    checkKernelSize(idim, kernels, pads);

    auto outSz =
        calculateConvPoolOutputDims(idim.h, idim.w, kernels, strides, pads);
    auto OT = getParent()->uniqueTypeWithNewShape(
        input.getType(), {idim.n, outSz.first, outSz.second, idim.c});
    return addNode(
        new AvgPoolNode(name, OT, input, kernels, strides, pads, layout));

  } else {
    ShapeNTHWC idim = ShapeNTHWC(input.dims());
    check3DKernelSize(idim, kernels, pads);

    auto outSz = calculate3DConvPoolOutputDims(idim.t, idim.h, idim.w, kernels,
                                               strides, pads);
    auto OT = getParent()->uniqueTypeWithNewShape(
        input.getType(),
        {idim.n, outSz.temporal_frames, outSz.height, outSz.width, idim.c});
    return addNode(
        new AvgPoolNode(name, OT, input, kernels, strides, pads, layout));
  }
}

AvgPoolNode *Function::createAvgPool(llvm::StringRef name, NodeValue input,
                                     TypeRef outTy,
                                     llvm::ArrayRef<unsigned_t> kernels,
                                     llvm::ArrayRef<unsigned_t> strides,
                                     llvm::ArrayRef<unsigned_t> pads,
                                     ConvolutionLayout layout) {
  if (!is3DData(layout)) {

    ShapeNHWC idim = ShapeNHWC(input.dims());
    ShapeHW kdim(kernels);
    (void)kdim;
    checkKernelSize(idim, kernels, pads);
    return addNode(
        new AvgPoolNode(name, outTy, input, kernels, strides, pads, layout));

  } else {

    ShapeNTHWC idim = ShapeNTHWC(input.dims());
    ShapeTHW kdim(kernels);
    (void)kdim;
    check3DKernelSize(idim, kernels, pads);
    return addNode(
        new AvgPoolNode(name, outTy, input, kernels, strides, pads, layout));
  }
}

AvgPoolNode *Function::createAvgPool(llvm::StringRef name, NodeValue input,
                                     unsigned_t kernel, unsigned_t stride,
                                     unsigned_t pad, ConvolutionLayout layout) {
  if (!is3DData(layout)) {

    llvm::SmallVector<unsigned_t, 4> pads = {pad, pad, pad, pad};
    llvm::SmallVector<unsigned_t, 2> strides = {stride, stride};
    llvm::SmallVector<unsigned_t, 2> kernels = {kernel, kernel};
    return createAvgPool(name, input, kernels, strides, pads, layout);

  } else {

    llvm::SmallVector<unsigned_t, 6> pads = {pad, pad, pad, pad, pad, pad};
    llvm::SmallVector<unsigned_t, 3> strides = {stride, stride, stride};
    llvm::SmallVector<unsigned_t, 3> kernels = {kernel, kernel, kernel};
    return createAvgPool(name, input, kernels, strides, pads, layout);
  }
}

AdaptiveAvgPoolNode *Function::createAdaptiveAvgPool(llvm::StringRef name,
                                                     NodeValue input,
                                                     TypeRef outTy) {
  return addNode(new AdaptiveAvgPoolNode(name, outTy, input));
}

GemmNode *Function::createGemm(llvm::StringRef name, NodeValue A, NodeValue B,
                               NodeValue C, float alpha, float beta,
                               bool transposeA, bool transposeB) {
  std::vector<dim_t> outDims(2);
  outDims[0] = transposeA ? A.dims()[1] : A.dims()[0];
  outDims[1] = transposeB ? B.dims()[0] : B.dims()[1];
  TypeRef outTy = getParent()->uniqueTypeWithNewShape(A.getType(), outDims);
  return createGemm(name, outTy, A, B, C, alpha, beta, transposeA, transposeB);
}

GemmNode *Function::createGemm(llvm::StringRef name, TypeRef outTy, NodeValue A,
                               NodeValue B, NodeValue C, float alpha,
                               float beta, bool transposeA, bool transposeB) {
  // If the C operand is not given, create a 1D splat with 0.
  if (!C.getNode()) {
    TypeRef splatTy =
        getParent()->uniqueTypeWithNewShape(outTy, {outTy->dims()[1]});
    C = createSplat(name.str() + ".SplatC", splatTy, 0.0f);
  }
  // If the C operand is a 2D constant, check whether it is a broadcast of a
  // 1D tensor. If so, slice and reshape the C operand to 1D.
  if (auto *constC = llvm::dyn_cast<Constant>(C.getNode())) {
    if ((constC->dims().size() == 2) && (constC->getPayload().isTiled(0))) {
      // Slice and reshape to 1D.
      dim_t lengthC = constC->dims()[1];
      C = createSlice(name.str() + ".SliceC", C, {0, 0}, {1, lengthC});
      C = createReshape(name.str() + ".ReshapeC", C, {lengthC});
    }
  }
  TypeRef OT = getParent()->uniqueType(*outTy);
  return addNode(
      new GemmNode(name, OT, A, B, C, alpha, beta, transposeA, transposeB));
}
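
// Usage sketch (illustrative): for A of shape {M, K}, B of shape {K, N} and
// an optional 1D C of shape {N}, the node computes
//   Y = alpha * op(A) * op(B) + beta * C
// where op(.) transposes its argument when the corresponding transpose flag
// is set; passing an empty NodeValue for C yields a zero splat of shape {N}.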

FullyConnectedNode *Function::createFullyConnected(llvm::StringRef name,
                                                   NodeValue input, Storage *W,
                                                   Storage *B,
                                                   unsigned_t axis) {
  TypeRef T = input.getType();
  TypeRef OT = getParent()->uniqueTypeWithNewShape(
      T, {input.dims()[0], B->getType()->dims()[0]});

  return createFullyConnected(name, input, W, B, OT, axis);
}

FullyConnectedNode *Function::createFullyConnected(llvm::StringRef name,
                                                   NodeValue input, NodeValue W,
                                                   NodeValue B,
                                                   unsigned_t axis) {
  TypeRef T = input.getType();
  TypeRef OT =
      getParent()->uniqueTypeWithNewShape(T, {input.dims()[0], B.dims()[0]});

  return createFullyConnected(name, input, W, B, OT, axis);
}

FullyConnectedNode *Function::createFullyConnected(llvm::StringRef name,
                                                   NodeValue input, NodeValue W,
                                                   NodeValue B, TypeRef outTy,
                                                   unsigned_t axis) {
  assert(outTy->dims().size() == 2 && "Invalid number of dimensions");
  assert(outTy->dims()[0] == input.dims()[0] && "Invalid dimensions");

  // FC always uses 2D input; flatten if necessary.
  if (input.dims().size() != 2) {
    input = createFlatten(name.str() + ".reshape2D", input, axis);
  }

  TypeRef OT = getParent()->uniqueType(*outTy);
  return addNode(new FullyConnectedNode(name, OT, input, W, B));
}

RowwiseQuantizedFullyConnectedNode *
Function::createRowwiseQuantizedFullyConnected(llvm::StringRef name,
                                               NodeValue input, Constant *W,
                                               Constant *scales,
                                               Constant *offsets, NodeValue B,
                                               TypeRef outTy) {
  return addNode(new RowwiseQuantizedFullyConnectedNode(name, outTy, input, W,
                                                        scales, offsets, B));
}

RowwiseQuantizedFullyConnectedNode *
Function::createRowwiseQuantizedFullyConnected(llvm::StringRef name,
                                               NodeValue input, Constant *W,
                                               NodeValue B, TypeRef outTy,
                                               quantization::Schema schema,
                                               bool transposeWeight) {
  // Since W is constant, quantize it at compile time.
  // The quantized data is in qWeights, the scale of each row is in scales,
  // and the offset of each row is in offsets.
  Constant *weights = llvm::cast<Constant>(W);
  dim_t numRows =
      transposeWeight ? W->getType()->dims()[1] : W->getType()->dims()[0];
  dim_t numCols =
      transposeWeight ? W->getType()->dims()[0] : W->getType()->dims()[1];

  // So far, creating a storage with Int8QTy/Int16QTy assumes the data is
  // already quantized and that a scale and offset are provided. For rowwise
  // quantization, however, the scales and offsets are stored in separate
  // vectors, so we add a dummy scale and offset here.
  auto *qWeights = getParent()->createConstant(
      ElemKind::Int8QTy, {numRows, numCols}, 0.0, 0, "weights.rwqfc");
  auto *scales =
      getParent()->createConstant(ElemKind::FloatTy, {numRows}, "scales.rwqfc");
  auto *offsets = getParent()->createConstant(ElemKind::Int32ITy, {numRows},
                                              "offsets.rwqfc");

  Tensor wt;
  if (transposeWeight) {
    // This happens when the RowwiseQuantizedFullyConnected node is converted
    // from a quantized FullyConnected node in Glow's quantization procedure.
    // In FC the weights are stored transposed (i.e. I * W + B), but in
    // RowwiseQuantizedFullyConnected the weights are stored as-is (i.e.
    // I * W(T) + B).
    weights->getPayloadMutable().transpose(&wt, {1, 0});
  } else {
    wt.assign(&(weights->getPayload()));
  }

  // Note: Using int32_t offset here as that is what RWQ-FC expects.
  quantization::tensorRowwiseQuantization<float, int32_t, int8_t>(
      wt, qWeights->getPayloadMutable(), scales->getPayloadMutable(),
      offsets->getPayloadMutable(), schema);

  return addNode(new RowwiseQuantizedFullyConnectedNode(
      name, outTy, input, qWeights, scales, offsets, B));
}
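
// For reference (illustrative): rowwise quantization derives a scale and
// offset per output row from that row's value range, so dequantization
// recovers approximately
//   W_float[r][c] ~= scales[r] * (W_int8[r][c] - offsets[r])
// which loses less precision than one scale/offset for the whole matrix.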

ReluNode *Function::createRELU(llvm::StringRef name, NodeValue input,
                               TypeRef outTy) {
  return addNode(new ReluNode(name, outTy, input));
}

ReluNode *Function::createRELU(llvm::StringRef name, NodeValue input) {
  return addNode(new ReluNode(name, input.getType(), input));
}

Node *Function::createGELU(llvm::StringRef name, NodeValue input) {
  auto outTy = input.getType();

  Node *alphaSplat =
      createSplat(name.str() + ".alpha", outTy, M_2_SQRTPI * M_SQRT1_2);
  Node *splat = createSplat(name.str() + ".splat", outTy, 0.044715);
  Node *splatHalf = createSplat(name.str() + ".splatHalf", outTy, 0.5);
  Node *splat1 = createSplat(name.str() + ".splat1", outTy, 1.0);
  Node *splat3 = createSplat(name.str() + ".splat3", outTy, 3.0);

  // pow(x, 3)
  Node *pow = createPow(name.str() + ".pow", input, splat3);

  // pow(x, 3) * 0.044715
  Node *mul = createMul(name.str() + ".mul", pow, splat);

  // x + pow(x, 3) * 0.044715
  Node *add = createAdd(name.str() + ".add", input, mul);

  // (x + pow(x, 3) * 0.044715) * alpha
  Node *mul2 = createMul(name.str() + ".mul2", add, alphaSplat);

  // tanh((x + pow(x, 3) * 0.044715) * alpha)
  Node *tanh = createTanh(name.str() + ".tanh", mul2);

  // tanh((x + pow(x, 3) * 0.044715) * alpha) + 1
  Node *add2 = createAdd(name.str() + ".add2", tanh, splat1);

  // (tanh((x + pow(x, 3) * 0.044715) * alpha) + 1) * 0.5
  Node *mul3 = createMul(name.str() + ".mul3", splatHalf, add2);

  // (tanh((x + pow(x, 3) * 0.044715) * alpha) + 1) * 0.5 * x
  return createMul(name.str() + ".mul4", mul3, input);
}
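
// The node chain above implements the tanh approximation of GELU
// (illustrative formula):
//   gelu(x) ~= 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
// with sqrt(2/pi) spelled as M_2_SQRTPI * M_SQRT1_2.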

PReluNode *Function::createPRELU(llvm::StringRef name, NodeValue input,
                                 NodeValue slope, TypeRef outTy) {
  return addNode(new PReluNode(name, outTy, input, slope));
}

PReluNode *Function::createPRELU(llvm::StringRef name, NodeValue input,
                                 NodeValue slope) {
  return addNode(new PReluNode(name, input.getType(), input, slope));
}

SigmoidNode *Function::createSigmoid(llvm::StringRef name, TypeRef outTy,
                                     NodeValue input) {
  return addNode(new SigmoidNode(name, outTy, input));
}

SigmoidNode *Function::createSigmoid(llvm::StringRef name, NodeValue input) {
  return createSigmoid(name, input.getType(), input);
}

SwishNode *Function::createSwish(llvm::StringRef name, NodeValue input) {
  return addNode(new SwishNode(name, input.getType(), input));
}

TanhNode *Function::createTanh(llvm::StringRef name, TypeRef outTy,
                               NodeValue input) {
  return addNode(new TanhNode(name, outTy, input));
}

TanhNode *Function::createTanh(llvm::StringRef name, NodeValue input) {
  return createTanh(name, input.getType(), input);
}

SoftMaxNode *Function::createSoftMax(llvm::StringRef name, NodeValue input,
                                     NodeValue selected, TypeRef outTy,
                                     float beta) {
  // Scale the input with beta before the softmax.
  if (beta != 1.0) {
    auto *splat = createSplat(name, input.getType(), beta);
    input = createMul(name, input, splat);
  }
  // By default, pick the input type.
  if (!outTy) {
    outTy = getParent()->uniqueType(*input.getType());
  }
  return addNode(new SoftMaxNode(name, outTy, input, selected));
}
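
// Behavior sketch (illustrative): with beta != 1 the input is pre-scaled, so
// the node effectively computes softmax(beta * x); larger beta sharpens the
// output distribution and smaller beta flattens it.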

CrossEntropyLossNode *Function::createCrossEntropyLoss(llvm::StringRef name,
                                                       NodeValue input,
                                                       NodeValue labels) {
  auto ty = getParent()->uniqueTypeWithNewShape(input.getType(), {1});
  return addNode(new CrossEntropyLossNode(name, ty, input, labels));
}

RegressionNode *Function::createRegression(llvm::StringRef name,
                                           NodeValue input,
                                           NodeValue expected) {
  return addNode(new RegressionNode(name, input, expected));
}

SigmoidCrossEntropyWithLogitsNode *
Function::createSigmoidCrossEntropyWithLogits(llvm::StringRef name,
                                              NodeValue logits,
                                              NodeValue targets) {
  assert(logits.dims().size() > 1);
  std::vector<dim_t> outDims(logits.dims().begin(), logits.dims().end() - 1);
  auto ty = getParent()->uniqueTypeWithNewShape(logits.getType(), outDims);
  return addNode(
      new SigmoidCrossEntropyWithLogitsNode(name, ty, logits, targets));
}

ReshapeNode *Function::createReshape(llvm::StringRef name, NodeValue input,
                                     llvm::ArrayRef<dim_t> shape,
                                     llvm::StringRef layout) {
  auto TR = getParent()->uniqueTypeWithNewShape(input.getType(), shape);
  DCHECK_EQ(TR->size(), input.getType()->size())
      << "Reshape to a different size";
  return addNode(new ReshapeNode(name, TR, input, shape.vec(), layout));
}

TransposeNode *Function::createTranspose(llvm::StringRef name, NodeValue input,
                                         llvm::ArrayRef<unsigned_t> shuffle,
                                         const std::string &layout) {
  ShapeVector shape;
  auto dims = input.dims();
  for (size_t i = 0; i < dims.size(); i++) {
    shape.push_back(dims[shuffle[i]]);
  }

  // If the layout is known, check that it matches the shuffle:
  auto compareShuffle = [&](const std::vector<unsigned_t> targetShuffle) {
    auto shuffleVec = shuffle.vec();
    return targetShuffle.size() == dims.size() &&
           std::equal(shuffleVec.begin(), shuffleVec.end(),
                      targetShuffle.begin());
  };

  auto currLayout = layout;
  if (currLayout == ANY_LAYOUT) {
    // If layout got a default value, change it based on shuffle:
    // TODO: remove the shuffle and replace it with layout.
    if (compareShuffle(NCHW2NHWC) || compareShuffle(HWCN2NHWC)) {
      currLayout = "NHWC";
    } else if (compareShuffle(NCTHW2NTHWC)) {
      currLayout = "NTHWC";
    } else if (compareShuffle(NHWC2NCHW)) {
      currLayout = "NCHW";
    } else if (compareShuffle(NTHWC2NCTHW)) {
      currLayout = "NCTHW";
    } else if (compareShuffle(NHWC2HWNC)) {
      currLayout = "HWNC";
    } else if (compareShuffle(CNHW2NHWC)) {
      currLayout = "NHWC";
    }
  }

  auto NT = getParent()->uniqueTypeWithNewShape(input.getType(), shape);
  return addNode(new TransposeNode(name, NT, input, shuffle.vec(), currLayout));
}

FlipNode *Function::createFlip(llvm::StringRef name, NodeValue input,
                               unsigned_t axis) {
  auto OT = getParent()->uniqueType(*input.getType());
  return addNode(new FlipNode(name, OT, input, axis));
}

Node *Function::createBroadcast(llvm::StringRef name, NodeValue input,
                                UnsignedArrayRef newShape, unsigned_t axis) {
  const auto &origDims = input.dims();

  assert(axis + origDims.size() <= newShape.size() &&
         "Axis must fit inside the newShape.");

  // Iterate over the new shape; if the original shape had a dimension here
  // (when considering the axis) then verify the dimension either matches the
  // new shape (no action taken) or == 1 (broadcast in that direction). Else
  // the original shape had no dimensions here (after considering axis), so
  // add the new dimension and broadcast in that direction.
  dim_t reshapeDims[max_tensor_dimensions];
  for (dim_t i = 0; i < newShape.size(); i++) {
    if (i >= axis && i < origDims.size() + axis) {
      const int origIdx = i - axis;
      if (origDims[origIdx] == newShape[i]) {
        // Keep original dimensions; they are compatible.
        reshapeDims[i] = origDims[origIdx];
      } else if (origDims[origIdx] == 1) {
        // Will broadcast this dimension to size from newShape.
        reshapeDims[i] = 1;
      } else {
        // Incompatible dimensions for broadcasting
        llvm_unreachable("Cannot broadcast with these dimensions.");
      }
    } else {
      // Will broadcast this dimension to size from newShape.
      reshapeDims[i] = 1;
    }
  }

  // Reshape the input node to same number of dimensions as new shape, but
  // with 1s in place of to-be-broadcasted dimensions.
  Node *currNode =
      createReshape(name.str() + ".reshape", input,
                    llvm::ArrayRef<dim_t>(reshapeDims, newShape.size()));

  // Create a Tile (which is really a Concat) in each direction that needs to
  // be broadcasted.
  for (size_t i = 0; i < newShape.size(); i++) {
    if (reshapeDims[i] == 1 && newShape[i] != 1) {
      currNode = createTile(name.str() + ".tile" + std::to_string(i), currNode,
                            newShape[i], i);
    }
  }

  return currNode;
}
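
// Worked example (illustrative): broadcasting an input of shape {3, 1} to
// newShape {2, 3, 4} with axis = 1 first reshapes it to {1, 3, 1}, then
// tiles dimension 0 by 2 and dimension 2 by 4, producing {2, 3, 4}.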

/// \returns true if \p T1 and \p T2 have the exact same type except for
/// dimension \p dim.
static bool sameSameShapeExceptDim(TypeRef T1, TypeRef T2, unsigned dim) {
  if (T1->getElementType() != T2->getElementType()) {
    return false;
  }

  auto D1 = T1->dims();
  auto D2 = T2->dims();

  if (D1.size() != D2.size()) {
    return false;
  }

  for (unsigned i = 0, e = D1.size(); i < e; i++) {
    // Ignore the dimension \p dim.
    if (i == dim) {
      continue;
    }

    if (D1[i] != D2[i]) {
      return false;
    }
  }

  return true;
}
1387
1388 ConcatNode *Function::createConcat(llvm::StringRef name,
1389 llvm::ArrayRef<NodeValue> inputs,
1390 unsigned_t dimension) {
1391 for (int i = 1, e = inputs.size(); i < e; i++) {
1392     assert(sameShapeExceptDim(inputs[i].getType(), inputs[0].getType(),
1393                               dimension) &&
1394            "Invalid type");
1395     (void)sameShapeExceptDim;
1396 }
1397 auto inDim = inputs[0].dims();
1398
1399 ShapeVector shape(inDim.begin(), inDim.end());
1400
1401 // We are stacking the tensors along a specific dimension. This means that
1402 // we increase the size of the tensor along this dimension.
1403 shape[dimension] = 0;
1404 for (auto I : inputs) {
1405 shape[dimension] += I.getType()->dims()[dimension];
1406 }
1407
1408 auto NT = getParent()->uniqueTypeWithNewShape(inputs[0].getType(), shape);
1409 std::vector<NodeValue> ops;
1410 ops.reserve(inputs.size());
1411 for (auto I : inputs) {
1412 ops.emplace_back(I);
1413 }
1414 return addNode(new ConcatNode(name, NT, ops, dimension));
1415 }
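// Example (editorial sketch, not part of the original source): assuming a
// Function *F and NodeValues `a` and `b`, both with dims {3, 4}, concatenating
// along dimension 0 yields a result of dims {6, 4}; all non-concat dimensions
// must match across the inputs:
//
//   ConcatNode *cn = F->createConcat("concat", {a, b}, /* dimension */ 0);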
1416
1417 ConcatNode *Function::createConcat(llvm::StringRef name,
1418 llvm::ArrayRef<NodeValue> inputs,
1419 unsigned_t dimension, TypeRef outTy) {
1420 std::vector<NodeValue> ops;
1421 ops.reserve(inputs.size());
1422 for (auto I : inputs) {
1423 ops.emplace_back(I);
1424 }
1425
1426 TypeRef OT = getParent()->uniqueType(*outTy);
1427 return addNode(new ConcatNode(name, OT, ops, dimension));
1428 }
1429
1430 TileNode *Function::createTile(llvm::StringRef name, NodeValue input,
1431 unsigned_t tiles, unsigned_t axis,
1432 TypeRef outTy) {
1433 assert(tiles > 0 && "Tiles must be non-zero.");
1434   assert(axis < input.dims().size() &&
1435          "Axis must fall in range of source dims.");
1436
1437 if (outTy == nullptr) {
1438 ShapeVector outShape(input.dims().begin(), input.dims().end());
1439 outShape[axis] *= tiles;
1440 outTy = getParent()->uniqueTypeWithNewShape(input.getType(), outShape);
1441 }
1442
1443 return addNode(new TileNode(name, outTy, input, tiles, axis));
1444 }
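// Example (editorial sketch, not part of the original source): assuming a
// Function *F and a NodeValue `in` with dims {2, 3}, tiling 4 times along
// axis 0 repeats the tensor's contents and yields dims {8, 3}:
//
//   TileNode *tn = F->createTile("tile", in, /* tiles */ 4, /* axis */ 0);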
1445
1446 InsertTensorNode *Function::createInsertTensor(llvm::StringRef name,
1447 NodeValue big, NodeValue small,
1448 llvm::ArrayRef<dim_t> start,
1449 unsigned_t count,
1450 unsigned_t axis) {
1451 return addNode(new InsertTensorNode(name, big, small, start, count, axis));
1452 }
1453
1454 SliceNode *Function::createSlice(llvm::StringRef name, NodeValue input,
1455 llvm::ArrayRef<dim_t> start, TypeRef outTy) {
1456 assert(input.dims().size() == start.size() &&
1457 "Start and input dims should match");
1458 assert(outTy->dims().size() == start.size() &&
1459 "Output and start dims should match");
1460
1461 for (unsigned i = 0, e = input.dims().size(); i < e; i++) {
1462 assert(start[i] + outTy->dims()[i] <= input.dims()[i] &&
1463 "Input/Output/Start dims mismatch");
1464 }
1465
1466 TypeRef OT = getParent()->uniqueType(*outTy);
1467 return addNode(new SliceNode(name, OT, input, start));
1468 }
1469
1470 SliceNode *Function::createSlice(llvm::StringRef name, NodeValue input,
1471 llvm::ArrayRef<dim_t> begin,
1472 llvm::ArrayRef<dim_t> end) {
1473 std::vector<dim_t> beginV, shape;
1474 auto dims = input.dims();
1475 assert(begin.size() == end.size() && "Begin and End dimensions should match");
1476 assert(begin.size() == dims.size() &&
1477 "Begin and Input dimensions should match");
1478 for (unsigned i = 0; i < dims.size(); i++) {
1479 dim_t beginI = begin[i];
1480 dim_t endI = end[i];
1481 dim_t dimI = dims[i];
1482 (void)dimI;
1483 assert(beginI >= 0 && "Illegal Begin indices");
1484 assert(endI > 0 && "Illegal End indices");
1485 assert(beginI < dimI && "Illegal Begin indices");
1486 assert(endI <= dimI && "Illegal End indices");
1487 assert(endI > beginI && "Illegal Begin and End indices");
1488 beginV.push_back(beginI);
1489 shape.push_back(endI - beginI);
1490 }
1491
1492 auto NT = getParent()->uniqueTypeWithNewShape(input.getType(), shape);
1493 return addNode(new SliceNode(name, NT, input, beginV));
1494 }
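// Example (editorial sketch, not part of the original source): assuming a
// Function *F and a NodeValue `in` with dims {4, 4}, slicing rows [1, 3) and
// all columns yields dims {2, 4}; `end` is exclusive in every dimension:
//
//   SliceNode *sn = F->createSlice("slice", in, /* begin */ {1, 0},
//                                  /* end */ {3, 4});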
1495
1496 Node *Function::createChannelShuffle(llvm::StringRef name, NodeValue input,
1497 size_t group, size_t kernel) {
1498 return addNode(
1499 new ChannelShuffleNode(name, input.getType(), input, group, kernel));
1500 }
1501
1502 ReshapeNode *Function::createSqueeze(llvm::StringRef name, NodeValue input,
1503 llvm::ArrayRef<dim_t> axes) {
1504 assert(!axes.empty() && "Parameter `axes` must be provided.");
1505
1506 ShapeVector shapeAxes(axes.begin(), axes.end());
1507
1508 // Sort and unique the values in axes to
1509 // 1. make sure each dim is only removed once;
1510 // 2. check if the size and value of dimensions to squeeze are valid.
1511 std::sort(shapeAxes.begin(), shapeAxes.end());
1512 shapeAxes.erase(std::unique(shapeAxes.begin(), shapeAxes.end()),
1513 shapeAxes.end());
1514 auto inDims = input.dims();
1515   assert(shapeAxes.back() < inDims.size() && "Each axis to squeeze must be "
1516                                              "a valid dimension index of "
1517                                              "the input.");
1518
1519 ShapeVector newDims;
1520 size_t j = 0;
1521 for (size_t i = 0, e = inDims.size(); i < e; i++) {
1522 if (j < shapeAxes.size() && shapeAxes[j] == i) {
1523 assert(inDims[i] == 1 && "The dimension to squeeze must be 1.");
1524 j++;
1525 } else {
1526 newDims.push_back(inDims[i]);
1527 }
1528 }
1529 return createReshape(name.str() + ".reshape", input, newDims);
1530 }
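// Example (editorial sketch, not part of the original source): assuming a
// Function *F and a NodeValue `in` with dims {1, 3, 1, 2}, squeezing axes
// {0, 2} removes the size-1 dimensions and reshapes to dims {3, 2}:
//
//   ReshapeNode *rn = F->createSqueeze("squeeze", in, /* axes */ {0, 2});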
1531
1532 ReshapeNode *Function::createExpandDims(llvm::StringRef name, NodeValue input,
1533 llvm::ArrayRef<dim_t> axes) {
1534 assert(!axes.empty() && "Parameter `axes` must be provided.");
1535
1536 // Dimensions provided in axes are for the output tensor, so we sort them
1537 // and unique them to make sure they are processed correctly and in the
1538 // right order.
1539 ShapeVector shapeAxes(axes.begin(), axes.end());
1540 std::sort(shapeAxes.begin(), shapeAxes.end());
1541 shapeAxes.erase(std::unique(shapeAxes.begin(), shapeAxes.end()),
1542 shapeAxes.end());
1543
1544 const auto inDims = input.dims();
1545
1546 // The total number of dimensions in the new shape is equal to the original
1547   // shape size plus the uniqued new shape axes, which represent where to
1548   // insert dimensions of 1 into the output tensor's shape.
1549 const size_t totalNumNewDims = shapeAxes.size() + inDims.size();
1550 assert(totalNumNewDims <= max_tensor_dimensions &&
1551 "New expanded shape has too many dimensions.");
1552 assert(shapeAxes.back() < totalNumNewDims &&
1553 "Specified axis expands outside size of output tensor shape.");
1554 ShapeVector newDims;
1555 for (size_t i = 0, j = 0, k = 0; k < totalNumNewDims; k++) {
1556 if (j < shapeAxes.size() && shapeAxes[j] == k) {
1557 newDims.push_back(1);
1558 j++;
1559 } else {
1560 assert(i < inDims.size() && "Somehow overflowing inDims.");
1561 newDims.push_back(inDims[i]);
1562 i++;
1563 }
1564 }
1565
1566 // Create a reshape of the original data with the newly determined
1567 // dimensions.
1568 return createReshape(name.str() + ".expanddims", input, newDims);
1569 }
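// Example (editorial sketch, not part of the original source): assuming a
// Function *F and a NodeValue `in` with dims {3, 2}, expanding axes {0, 3}
// inserts size-1 dimensions at those output positions, giving dims
// {1, 3, 2, 1}:
//
//   ReshapeNode *rn = F->createExpandDims("expand", in, /* axes */ {0, 3});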
1570
1571 ReshapeNode *Function::createFlatten(llvm::StringRef name, NodeValue input,
1572 unsigned_t axis) {
1573 auto xDim = flattenCdr(input.getType()->dims(), axis);
1574 return createReshape(name, input, {xDim.first, xDim.second});
1575 }
1576
1577 void Function::createSplit(llvm::StringRef name, NodeValue input,
1578 unsigned_t outputNum, unsigned_t axis,
1579 llvm::ArrayRef<dim_t> split,
1580 std::vector<SliceNode *> &outputs) {
1581 auto inDims = input.dims();
1582 if (split.empty()) {
1583 assert(inDims[axis] % outputNum == 0 &&
1584 "Dimension to split must be divisible by outputs number.");
1585 } else {
1586 assert(outputNum == split.size() &&
1587 "Number of splits must be divisible by outputs number.");
1588 }
1589
1590 ShapeVector start(inDims.size(), 0);
1591 ShapeVector end(inDims.begin(), inDims.end());
1592 end[axis] = 0;
1593
1594 outputs.resize(outputNum);
1595 for (size_t i = 0; i < outputNum; i++) {
1596 size_t curLength = split.empty() ? inDims[axis] / outputNum : split[i];
1597 end[axis] += curLength;
1598 outputs[i] =
1599 createSlice(name.str() + ".out" + std::to_string(i), input, start, end);
1600 start[axis] = end[axis];
1601 }
1602
1603 assert(end[axis] == inDims[axis] &&
1604 "Total size of results must be equal to input size.");
1605 }
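// Example (editorial sketch, not part of the original source): assuming a
// Function *F and a NodeValue `in` with dims {6, 4}, an even 3-way split
// along axis 0 fills `outs` with three SliceNodes of dims {2, 4} each:
//
//   std::vector<SliceNode *> outs;
//   F->createSplit("split", in, /* outputNum */ 3, /* axis */ 0,
//                  /* split */ {}, outs);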
1606
1607 BatchNormalizationNode *Function::createBatchNormalization(
1608 llvm::StringRef name, NodeValue input, NodeValue beta, NodeValue scale,
1609 NodeValue mean, NodeValue var, unsigned_t channelIdx, float epsilon,
1610 float momentum) {
1611 return addNode(new BatchNormalizationNode(name, input, scale, beta, mean, var,
1612 channelIdx, epsilon, momentum));
1613 }
1614
1615 LayerNormalizationNode *Function::createLayerNormalization(llvm::StringRef name,
1616 NodeValue input,
1617 NodeValue scale,
1618 NodeValue bias,
1619 float epsilon) {
1620 return addNode(new LayerNormalizationNode(name, input, scale, bias, epsilon));
1621 }
1622
1623 BucketizeNode *Function::createBucketizeNode(llvm::StringRef name,
1624 NodeValue input,
1625 llvm::ArrayRef<float> boundaries) {
1626 auto OT = getParent()->uniqueType(ElemKind::Int32ITy, input.dims());
1627 return addNode(new BucketizeNode(name, OT, input, boundaries));
1628 }
1629
1630 LocalResponseNormalizationNode *Function::createLocalResponseNormalization(
1631 llvm::StringRef name, NodeValue input, unsigned_t halfWindowSize,
1632 float alpha, float beta, float k) {
1633 // The output tensor is of the same shape as the input tensor.
1634 return addNode(new LocalResponseNormalizationNode(name, input, halfWindowSize,
1635 alpha, beta, k));
1636 }
1637
1638 ModuloNode *Function::createModulo(llvm::StringRef name, NodeValue input,
1639 int64_t divisor, bool signFollowDivisor) {
1640 // The output tensor is of the same shape as the input tensor.
1641 auto OT = getParent()->uniqueType(*input.getType());
1642 return addNode(new ModuloNode(name, OT, input, divisor, signFollowDivisor));
1643 }
1644
1645 NotNode *Function::createNot(llvm::StringRef name, NodeValue input) {
1646 TypeRef OT = getParent()->uniqueType(ElemKind::BoolTy, input.dims());
1647 return addNode(new NotNode(name, OT, input));
1648 }
1649
1650 #define UNARY_ARITHMETIC_FUN_DEF(NODE_NAME_) \
1651 NODE_NAME_##Node *Function::create##NODE_NAME_(llvm::StringRef name, \
1652 NodeValue input) { \
1653 return create##NODE_NAME_(name, input.getType(), input); \
1654 } \
1655 NODE_NAME_##Node *Function::create##NODE_NAME_(llvm::StringRef name, \
1656 TypeRef T, NodeValue input) { \
1657 TypeRef OT = getParent()->uniqueType(*T); \
1658 return addNode(new NODE_NAME_##Node(name, OT, input)); \
1659 }
1660 UNARY_ARITHMETIC_FUN_DEF(Abs)
1661 UNARY_ARITHMETIC_FUN_DEF(Neg)
1662 UNARY_ARITHMETIC_FUN_DEF(Floor)
1663 UNARY_ARITHMETIC_FUN_DEF(Ceil)
1664 UNARY_ARITHMETIC_FUN_DEF(Round)
1665 UNARY_ARITHMETIC_FUN_DEF(Sqrt)
1666 UNARY_ARITHMETIC_FUN_DEF(Rsqrt)
1667 UNARY_ARITHMETIC_FUN_DEF(Reciprocal)
1668 UNARY_ARITHMETIC_FUN_DEF(Sin)
1669 UNARY_ARITHMETIC_FUN_DEF(Cos)
1670 #undef UNARY_ARITHMETIC_FUN_DEF
1671
1672 #define ARITHMETIC_FUN_DEF(NODE_NAME_) \
1673 NODE_NAME_##Node *Function::create##NODE_NAME_( \
1674 llvm::StringRef name, NodeValue LHS, NodeValue RHS) { \
1675 return create##NODE_NAME_(name, LHS.getType(), LHS, RHS); \
1676 } \
1677 NODE_NAME_##Node *Function::create##NODE_NAME_( \
1678 llvm::StringRef name, TypeRef T, NodeValue LHS, NodeValue RHS) { \
1679 DCHECK(LHS.dims() == RHS.dims()) \
1680 << "Invalid operand shapes " << LHS.dims() << " vs " << RHS.dims(); \
1681 TypeRef OT = getParent()->uniqueType(*T); \
1682 return addNode(new NODE_NAME_##Node(name, OT, LHS, RHS)); \
1683 }
1684 ARITHMETIC_FUN_DEF(Add);
1685 ARITHMETIC_FUN_DEF(Mul);
1686 ARITHMETIC_FUN_DEF(Sub);
1687 ARITHMETIC_FUN_DEF(Div);
1688 ARITHMETIC_FUN_DEF(Max);
1689 ARITHMETIC_FUN_DEF(Min);
1690 ARITHMETIC_FUN_DEF(Pow);
1691 #undef ARITHMETIC_FUN_DEF
1692
1693 AndNode *Function::createAnd(llvm::StringRef name, NodeValue LHS,
1694 NodeValue RHS) {
1695 assert(LHS.dims() == RHS.dims() && "Invalid operand shapes");
1696 TypeRef OT = getParent()->uniqueType(ElemKind::BoolTy, LHS.dims());
1697 return addNode(new AndNode(name, OT, LHS, RHS));
1698 }
1699
1700 OrNode *Function::createOr(llvm::StringRef name, NodeValue LHS, NodeValue RHS) {
1701 assert(LHS.dims() == RHS.dims() && "Invalid operand shapes");
1702 TypeRef OT = getParent()->uniqueType(ElemKind::BoolTy, LHS.dims());
1703 return addNode(new OrNode(name, OT, LHS, RHS));
1704 }
1705
1706 XorNode *Function::createXor(llvm::StringRef name, NodeValue LHS,
1707 NodeValue RHS) {
1708 assert(LHS.dims() == RHS.dims() && "Invalid operand shapes");
1709 TypeRef OT = getParent()->uniqueType(ElemKind::BoolTy, LHS.dims());
1710 return addNode(new XorNode(name, OT, LHS, RHS));
1711 }
1712
1713 CmpLTENode *Function::createCmpLTE(llvm::StringRef name, NodeValue LHS,
1714 NodeValue RHS) {
1715 assert(LHS.dims() == RHS.dims() && "Invalid operand shapes");
1716 TypeRef OT = getParent()->uniqueType(ElemKind::BoolTy, LHS.dims());
1717 return addNode(new CmpLTENode(name, OT, LHS, RHS));
1718 }
1719
1720 CmpLTNode *Function::createCmpLT(llvm::StringRef name, NodeValue LHS,
1721 NodeValue RHS) {
1722 assert(LHS.dims() == RHS.dims() && "Invalid operand shapes");
1723 TypeRef OT = getParent()->uniqueType(ElemKind::BoolTy, LHS.dims());
1724 return addNode(new CmpLTNode(name, OT, LHS, RHS));
1725 }
1726
1727 CmpLTENode *Function::createCmpGTE(llvm::StringRef name, NodeValue LHS,
1728 NodeValue RHS) {
1729 assert(LHS.dims() == RHS.dims() && "Invalid operand shapes");
1730 TypeRef OT = getParent()->uniqueType(ElemKind::BoolTy, LHS.dims());
1731 return addNode(new CmpLTENode(name, OT, RHS, LHS));
1732 }
1733
1734 CmpLTNode *Function::createCmpGT(llvm::StringRef name, NodeValue LHS,
1735 NodeValue RHS) {
1736 assert(LHS.dims() == RHS.dims() && "Invalid operand shapes");
1737 TypeRef OT = getParent()->uniqueType(ElemKind::BoolTy, LHS.dims());
1738 return addNode(new CmpLTNode(name, OT, RHS, LHS));
1739 }
1740
1741 CmpEQNode *Function::createCmpEQ(llvm::StringRef name, NodeValue LHS,
1742 NodeValue RHS) {
1743 assert(LHS.dims() == RHS.dims() && "Invalid operand shapes");
1744 TypeRef OT = getParent()->uniqueType(ElemKind::BoolTy, LHS.dims());
1745 return addNode(new CmpEQNode(name, OT, LHS, RHS));
1746 }
1747
1748 CmpNEQNode *Function::createCmpNEQ(llvm::StringRef name, NodeValue LHS,
1749 NodeValue RHS) {
1750 assert(LHS.dims() == RHS.dims() && "Invalid operand shapes");
1751 TypeRef OT = getParent()->uniqueType(ElemKind::BoolTy, LHS.dims());
1752 return addNode(new CmpNEQNode(name, OT, LHS, RHS));
1753 }
1754
1755 MulNode *Function::createSquare(llvm::StringRef name, NodeValue input) {
1756 return createMul(name, input, input);
1757 }
1758
1759 MulNode *Function::createSquare(llvm::StringRef name, TypeRef outTy,
1760 NodeValue input) {
1761 return createMul(name, outTy, input, input);
1762 }
1763
1764 PReluNode *Function::createLeakyRELU(llvm::StringRef name, NodeValue input,
1765 float alpha) {
1766 return createLeakyRELU(name, input.getType(), input, alpha);
1767 }
1768
1769 PReluNode *Function::createLeakyRELU(llvm::StringRef name, TypeRef outTy,
1770 NodeValue input, float alpha) {
1771 auto splatType = getParent()->uniqueType(*(input.getType()));
1772 SplatNode *splat = createSplat(name.str() + ".alpha", splatType, alpha);
1773 auto OT = getParent()->uniqueType(*outTy);
1774 return createPRELU(name, input, splat, OT);
1775 }
1776
1777 IsNaNNode *Function::createIsNaN(llvm::StringRef name, NodeValue input) {
1778 TypeRef OT = getParent()->uniqueType(ElemKind::BoolTy, input.dims());
1779 return addNode(new IsNaNNode(name, OT, input));
1780 }
1781
1782 ReplaceNaNNode *Function::createReplaceNaN(llvm::StringRef name,
1783 NodeValue input, float value) {
1784 return addNode(new ReplaceNaNNode(name, input.getType(), input, value));
1785 }
1786
1787 PowNode *Function::createPow(llvm::StringRef name, NodeValue base, float exp) {
1788 auto *SP = createSplat(name, base.getType(), exp);
1789 return createPow(name, base, SP);
1790 }
1791
1792 LogNode *Function::createLog(llvm::StringRef name, NodeValue input,
1793 TypeRef outTy) {
1794 return addNode(new LogNode(name, outTy ? outTy : input.getType(), input));
1795 }
1796
1797 ExpNode *Function::createExp(llvm::StringRef name, NodeValue input) {
1798 return addNode(new ExpNode(name, input.getType(), input));
1799 }
1800
1801 ExpNode *Function::createExp(llvm::StringRef name, TypeRef outTy,
1802 NodeValue input) {
1803 return addNode(new ExpNode(name, outTy, input));
1804 }
1805
1806 LogitNode *Function::createLogit(llvm::StringRef name, NodeValue input,
1807 float eps) {
1808 return addNode(new LogitNode(name, input.getType(), input, eps));
1809 }
1810
1811 SelectNode *Function::createSelect(llvm::StringRef name, TypeRef outTy,
1812 NodeValue Cond, NodeValue LHS,
1813 NodeValue RHS) {
1814 assert(LHS.dims() == RHS.dims() && "Invalid operand shapes");
1815 assert(LHS.dims() == Cond.dims() && "Invalid operand shapes");
1816 assert(LHS.dims() == outTy->dims() && "Invalid result shape");
1817 auto OT = getParent()->uniqueType(*outTy);
1818 return addNode(new SelectNode(name, OT, Cond, LHS, RHS));
1819 }
1820
1821 SelectNode *Function::createSelect(llvm::StringRef name, NodeValue Cond,
1822 NodeValue LHS, NodeValue RHS) {
1823 auto inDims = LHS.dims();
1824 assert(inDims.size() > 0);
1825 ShapeVector outDims(inDims.begin(), inDims.end());
1826 auto OT = getParent()->uniqueTypeWithNewShape(LHS.getType(), outDims);
1827 return createSelect(name, OT, Cond, LHS, RHS);
1828 }
1829
1830 SplatNode *Function::createSplat(llvm::StringRef name, TypeRef ty,
1831 float value) {
1832 return addNode(new SplatNode(name, getParent()->uniqueType(*ty), value));
1833 }
1834
1835 TouchNode *Function::createTouch(llvm::StringRef name, TypeRef ty) {
1836 return addNode(new TouchNode(name, getParent()->uniqueType(*ty)));
1837 }
1838
1839 MatMulNode *Function::createMatMul(llvm::StringRef name, TypeRef outTy,
1840 NodeValue lhs, NodeValue rhs) {
1841 return addNode(
1842 new MatMulNode(name, getParent()->uniqueType(*outTy), lhs, rhs));
1843 }
1844
1845 MatMulNode *Function::createMatMul(llvm::StringRef name, NodeValue lhs,
1846 NodeValue rhs) {
1847 auto LT = lhs.getType();
1848 auto RT = rhs.getType();
1849 auto LDims = LT->dims();
1850 auto RDims = RT->dims();
1851 assert(lhs.getType()->getElementType() == rhs.getType()->getElementType());
1852
1853 auto ty =
1854 getParent()->uniqueTypeWithNewShape(lhs.getType(), {LDims[0], RDims[1]});
1855 return createMatMul(name, ty, lhs, rhs);
1856 }
1857
1858 BatchMatMulNode *Function::createBatchMatMul(llvm::StringRef name,
1859 NodeValue LHS, NodeValue RHS) {
1860 const size_t numDimsRHS = RHS.dims().size();
1861 assert(LHS.dims().size() == 3 && "LHS must be 3 dimensional.");
1862 assert((numDimsRHS == 2 || numDimsRHS == 3) &&
1863 "RHS must be 2 or 3 dimensional.");
1864
1865   // If necessary, expand the RHS input to be 3D by adding an initial
1866   // leading dimension.
1867   if (numDimsRHS == 2) {
1868     RHS = createExpandDims(name.str() + ".reshapeRHS", RHS, {0});
1869   }
1870   // If necessary, tile the RHS input so it matches the numBatches of LHS.
1871   if (RHS.dims()[0] == 1 && LHS.dims()[0] != 1) {
1872     RHS = createTile(name.str() + ".tileRHS", RHS, LHS.dims()[0], /* axis */ 0);
1873   }
1874
1875 // LHS = {numBatches, N, M}
1876 // RHS = {numBatches, M, P}
1877 // Result = {numBatches, N, P}
1878 const dim_t numBatches = LHS.dims()[0];
1879 const dim_t N = LHS.dims()[1];
1880 const dim_t M = LHS.dims()[2];
1881 (void)M;
1882 const dim_t P = RHS.dims()[2];
1883 assert((RHS.dims()[0] == numBatches) && "Batch sizes are invalid.");
1884 assert((RHS.dims()[1] == M) && "Batch matmul dimensions are invalid.");
1885
1886 auto OT =
1887 getParent()->uniqueTypeWithNewShape(LHS.getType(), {numBatches, N, P});
1888 return addNode(new BatchMatMulNode(name, OT, LHS, RHS));
1889 }
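// Example (editorial sketch, not part of the original source): assuming a
// Function *F, a NodeValue `lhs` with dims {4, 2, 3} and a 2-D `rhs` with
// dims {3, 5}, the RHS is expanded to {1, 3, 5} and tiled to {4, 3, 5}, so
// the result has dims {4, 2, 5}:
//
//   BatchMatMulNode *bmm = F->createBatchMatMul("bmm", lhs, rhs);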
1890
1891 BatchedReduceAddNode *
1892 Function::createBatchedReduceAdd(llvm::StringRef name, TypeRef outTy,
1893 NodeValue batch,
1894 llvm::ArrayRef<unsigned_t> axes) {
1895 assert(axes.size() == 1 && "Only supporting single reduction for now.");
1896 auto axis = axes[0];
1897
1898 // Calculate the expected total number of elements in the output tensor
1899 // based on the number of elements in the batch divided by the axis
1900 // dimension.
1901 const size_t outNumElements = batch.getType()->size() / batch.dims()[axis];
1902 (void)outNumElements;
1903 assert(outTy->size() == outNumElements &&
1904 "Incorrect number of elements in the output type.");
1905 auto OT = getParent()->uniqueType(*outTy);
1906 return addNode(new BatchedReduceAddNode(name, OT, batch, axis));
1907 }
1908
1909 BatchedReduceAddNode *
1910 Function::createBatchedReduceAdd(llvm::StringRef name, NodeValue batch,
1911 llvm::ArrayRef<unsigned_t> axes) {
1912 auto outDims = getNewShapeWithoutAxes(batch.dims(), axes);
1913 auto OT = getParent()->uniqueTypeWithNewShape(batch.getType(), outDims);
1914 return createBatchedReduceAdd(name, OT, batch, axes);
1915 }
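// Example (editorial sketch, not part of the original source): assuming a
// Function *F and a NodeValue `batch` with dims {2, 3, 4}, reducing over
// axis 1 sums away that dimension and yields dims {2, 4}:
//
//   auto *r = F->createBatchedReduceAdd("reduce", batch, /* axes */ {1});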
1916
1917 BatchedReduceMeanNode *
1918 Function::createBatchedReduceMean(llvm::StringRef name, TypeRef outTy,
1919 NodeValue batch,
1920 llvm::ArrayRef<unsigned_t> axes) {
1921 auto OT = getParent()->uniqueType(*outTy);
1922 return addNode(new BatchedReduceMeanNode(name, OT, batch, axes));
1923 }
1924
1925 BatchedReduceMeanNode *
1926 Function::createBatchedReduceMean(llvm::StringRef name, NodeValue batch,
1927 llvm::ArrayRef<unsigned_t> axes) {
1928 // Create new shape with specified dimensions either reduced or removed.
1929 auto outDims = getNewShapeWithoutAxes(batch.dims(), axes);
1930 auto OT = getParent()->uniqueTypeWithNewShape(batch.getType(), outDims);
1931 return createBatchedReduceMean(name, OT, batch, axes);
1932 }
1933
1934 BatchedReduceMinNode *
1935 Function::createBatchedReduceMin(llvm::StringRef name, NodeValue batch,
1936 llvm::ArrayRef<unsigned_t> axes) {
1937 // Create new shape with specified dimensions either reduced or removed.
1938 auto outDims = getNewShapeWithoutAxes(batch.dims(), axes);
1939 auto OT = getParent()->uniqueType(batch.getType()->getElementType(), outDims);
1940 return addNode(new BatchedReduceMinNode(name, OT, batch, axes));
1941 }
1942
1943 BatchedAddNode *Function::createBatchedAdd(llvm::StringRef name,
1944 NodeValue batch, NodeValue slice) {
1945 return addNode(new BatchedAddNode(name, batch.getType(), batch, slice));
1946 }
1947
1948 BatchedAddNode *Function::createBatchedAdd(llvm::StringRef name, TypeRef outTy,
1949 NodeValue batch, NodeValue slice) {
1950 return addNode(
1951 new BatchedAddNode(name, getParent()->uniqueType(*outTy), batch, slice));
1952 }
1953
1954 CumSumNode *Function::createCumSum(llvm::StringRef name, NodeValue input,
1955 bool exclusive, bool reverse) {
1956 return addNode(
1957 new CumSumNode(name, input.getType(), input, exclusive, reverse));
1958 }
1959
1960 LengthsSumNode *Function::createLengthsSum(llvm::StringRef name, NodeValue data,
1961 NodeValue lengths) {
1962 ShapeVector outDims(data.dims().begin(), data.dims().end());
1963 outDims[0] = lengths.dims()[0];
1964 auto outTy = getParent()->uniqueTypeWithNewShape(data.getType(), outDims);
1965 return addNode(new LengthsSumNode(name, outTy, data, lengths));
1966 }
1967
1968 SparseLengthsSumNode *
1969 Function::createSparseLengthsSum(llvm::StringRef name, NodeValue data,
1970 NodeValue indices, NodeValue lengths,
1971 LengthsMode lengthsMode, float avgLength) {
1972 auto inDims = data.dims();
1973 ShapeVector outDims(inDims.begin(), inDims.end());
1974 outDims[0] = lengths.dims()[0];
1975 auto outTy = getParent()->uniqueTypeWithNewShape(data.getType(), outDims);
1976 return addNode(new SparseLengthsSumNode(name, outTy, data, indices, lengths,
1977 lengthsMode, avgLength));
1978 }
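// Example (editorial sketch, not part of the original source): assuming a
// Function *F, an embedding table `data` with dims {5, 4}, `indices` with
// values {0, 2, 1, 3, 4}, and `lengths` with values {3, 2}, the first output
// row is data[0] + data[2] + data[1] and the second is data[3] + data[4],
// giving an output of dims {2, 4}:
//
//   auto *sls = F->createSparseLengthsSum("sls", data, indices, lengths);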
1979
1980 SparseLengthsWeightedSumNode *Function::createSparseLengthsWeightedSum(
1981 llvm::StringRef name, NodeValue data, NodeValue weights, NodeValue indices,
1982 NodeValue lengths, LengthsMode lengthsMode, float avgLength) {
1983 auto inDims = data.dims();
1984 ShapeVector outDims(inDims.begin(), inDims.end());
1985 outDims[0] = lengths.dims()[0];
1986 auto outTy = getParent()->uniqueTypeWithNewShape(data.getType(), outDims);
1987 return addNode(new SparseLengthsWeightedSumNode(
1988 name, outTy, data, weights, indices, lengths, lengthsMode, avgLength));
1989 }
1990
1991 SparseLengthsWeightedSumNode *Function::createSparseLengthsWeightedSum(
1992 llvm::StringRef name, TypeRef outTy, NodeValue data, NodeValue weights,
1993 NodeValue indices, NodeValue lengths, LengthsMode lengthsMode,
1994 float avgLength) {
1995 return addNode(new SparseLengthsWeightedSumNode(
1996 name, outTy, data, weights, indices, lengths, lengthsMode, avgLength));
1997 }
1998
1999 RowwiseQuantizedSparseLengthsWeightedSumNode *
2000 Function::createRowwiseQuantizedSparseLengthsWeightedSum(
2001 llvm::StringRef name, Storage *data, Constant *scales, Constant *offsets,
2002 NodeValue weights, NodeValue indices, NodeValue lengths, ElemKind precision,
2003 bool useFP16Accumulation, LengthsMode lengthsMode, float avgLength) {
2004 auto inDims = data->dims();
2005 ShapeVector outDims(inDims.begin(), inDims.end());
2006 outDims[0] = lengths.dims()[0];
2007 auto outTy = getParent()->uniqueType(precision, outDims);
2008 return addNode(new RowwiseQuantizedSparseLengthsWeightedSumNode(
2009 name, outTy, data, scales, offsets, weights, indices, lengths,
2010 useFP16Accumulation, lengthsMode, avgLength));
2011 }
2012
2013 RowwiseQuantizedSparseLengthsWeightedSumNode *
2014 Function::createRowwiseQuantizedSparseLengthsSum(
2015 llvm::StringRef name, Storage *data, Constant *scales, Constant *offsets,
2016 NodeValue indices, NodeValue lengths, ElemKind precision,
2017 bool useFP16Accumulation, LengthsMode lengthsMode, float avgLength) {
2018 auto ty = getParent()->uniqueType(precision, {indices.dims()[0]});
2019 auto ones = createSplat(name.str() + ".ones", ty, 1.0);
2020 return createRowwiseQuantizedSparseLengthsWeightedSum(
2021 name, data, scales, offsets, ones, indices, lengths, precision,
2022 useFP16Accumulation, lengthsMode, avgLength);
2023 }
2024
2025 /// Helper to create a RowwiseQuantizedSparseLengthsWeightedSumNode in the
2026 /// Function \p F with \p name, using \p data, \p weights, \p indices, and \p
2027 /// lengths as inputs. The provided float data in \p Tensor is rowwise
2028 /// quantized, creating Constants for the rowwise quantized data as well as
2029 /// Scales and Offsets, in the Module containing \p F.
2030 static RowwiseQuantizedSparseLengthsWeightedSumNode *
2031 quantizeDataAndCreateRowwiseQuantizedSparseLengthsWeightedSum(
2032 Function *F, llvm::StringRef name, Tensor &data, NodeValue weights,
2033 NodeValue indices, NodeValue lengths, quantization::Schema schema,
2034 ElemKind precision, bool useFP16Accumulation, LengthsMode lengthsMode,
2035 float avgLength) {
2036 auto inDims = data.dims();
2037
2038 // Note: In rwqData, we are using a quantized type, however the scale/offset
2039 // are set to dummy values 0.0/0. This is because the actually used
2040 // scale/offset come from dataScales and dataOffsets.
2041 Constant *rwqData = F->getParent()->createConstant(ElemKind::UInt8QTy, inDims,
2042 0.0, 0, "data");
2043 Constant *dataScales =
2044 F->getParent()->createConstant(precision, {inDims[0]}, "dataScales");
2045 Constant *dataOffsets =
2046 F->getParent()->createConstant(precision, {inDims[0]}, "dataOffsets");
2047
2048 // Note: Using floating point offset here as that is what RWQ-SLWS expects.
2049 switch (precision) {
2050 case ElemKind::FloatTy:
2051 quantization::tensorRowwiseQuantization<float, float, uint8_t>(
2052 data, rwqData->getPayloadMutable(), dataScales->getPayloadMutable(),
2053 dataOffsets->getPayloadMutable(), schema);
2054 break;
2055 case ElemKind::Float16Ty:
2056 quantization::tensorRowwiseQuantization<float16_t, float16_t, uint8_t>(
2057 data, rwqData->getPayloadMutable(), dataScales->getPayloadMutable(),
2058 dataOffsets->getPayloadMutable(), schema);
2059 break;
2060 default:
2061 LOG(FATAL) << "Unsupported precision for RWQ-SLWS.";
2062 }
2063 return F->createRowwiseQuantizedSparseLengthsWeightedSum(
2064 name, rwqData, dataScales, dataOffsets, weights, indices, lengths,
2065 precision, useFP16Accumulation, lengthsMode, avgLength);
2066 }
2067
2068 RowwiseQuantizedSparseLengthsWeightedSumNode *
2069 Function::createRowwiseQuantizedSparseLengthsWeightedSum(
2070 llvm::StringRef name, Tensor &data, NodeValue weights, NodeValue indices,
2071 NodeValue lengths, quantization::Schema schema, ElemKind precision,
2072 bool useFP16Accumulation, LengthsMode lengthsMode, float avgLength) {
2073 return quantizeDataAndCreateRowwiseQuantizedSparseLengthsWeightedSum(
2074 this, name, data, weights, indices, lengths, schema, precision,
2075 useFP16Accumulation, lengthsMode, avgLength);
2076 }
2077
2078 RowwiseQuantizedSparseLengthsWeightedSumNode *
2079 Function::createRowwiseQuantizedSparseLengthsSum(
2080 llvm::StringRef name, Tensor &data, NodeValue indices, NodeValue lengths,
2081 quantization::Schema schema, ElemKind precision, bool useFP16Accumulation,
2082 LengthsMode lengthsMode, float avgLength) {
2083 auto ty = getParent()->uniqueType(precision, {indices.dims()[0]});
2084 auto ones = createSplat(name.str() + ".ones", ty, 1.0);
2085 return quantizeDataAndCreateRowwiseQuantizedSparseLengthsWeightedSum(
2086 this, name, data, ones, indices, lengths, schema, precision,
2087 useFP16Accumulation, lengthsMode, avgLength);
2088 }
2089
2090 /// Helper used to get the specific output type required for
2091 /// createFusedRowwiseQuantizedSparseLengthsSum,
2092 /// createFusedRowwiseQuantizedSparseLengthsWeightedSum, and
2093 /// createEmbeddingBagByteRowwiseOffsets. Function \p F is used to unique the
2094 /// type, using inputs \p data and \p segmentsDim to compute output dimensions.
2095 static TypeRef
2096 getOutputTypeOfFusedRowwiseQuantizedSLS(Function *F, NodeValue data,
2097 llvm::ArrayRef<dim_t> segmentsDim) {
2098 ShapeVector outDims(data.dims().begin(), data.dims().end());
2099 outDims[0] = segmentsDim[0];
2100 // The output column count is the same as the input column count, but
2101 // without the extra bytes for the fused scale/offset, as the output is not
2102 // fused.
2103 CHECK(isFusedQuantizedElemKind(data.getElementType()))
2104 << "Must use a fused ElemKind for data.";
2105 outDims[1] -= 2 * ((data.getElementType() == ElemKind::UInt8FusedQTy)
2106 ? sizeof(float)
2107 : sizeof(float16_t));
2108 // If using 4-bit quantization, then the input data has packed two 4-bit
2109 // elements into one byte, so we need to double the outDims.
2110 if (data.getElementType() == ElemKind::UInt4FusedFP16QTy) {
2111 outDims[1] *= 2;
2112 }
2113 const ElemKind outputK = (data.getElementType() == ElemKind::UInt8FusedQTy)
2114 ? ElemKind::FloatTy
2115 : ElemKind::Float16Ty;
2116 return F->getParent()->uniqueType(outputK, outDims);
2117 }
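// Worked example (editorial, not part of the original source): for
// UInt8FusedQTy data of dims {R, 72}, each row carries a fused fp32
// scale/offset pair (8 bytes), so the float output has 72 - 8 = 64 columns.
// For UInt4FusedFP16QTy data of dims {R, 36}, the fused fp16 pair takes 4
// bytes and each remaining byte packs two 4-bit elements, so the float16
// output has (36 - 4) * 2 = 64 columns.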
2118
2119 FusedRowwiseQuantizedSparseLengthsWeightedSumNode *
2120 Function::createFusedRowwiseQuantizedSparseLengthsWeightedSum(
2121 llvm::StringRef name, NodeValue data, NodeValue weights, NodeValue indices,
2122 NodeValue lengths, bool useFP16Accumulation, LengthsMode lengthsMode,
2123 float avgLength) {
2124 auto outTy =
2125 getOutputTypeOfFusedRowwiseQuantizedSLS(this, data, lengths.dims());
2126 return addNode(new FusedRowwiseQuantizedSparseLengthsWeightedSumNode(
2127 name, outTy, data, weights, indices, lengths, useFP16Accumulation,
2128 lengthsMode, avgLength));
2129 }
2130
2131 FusedRowwiseQuantizedSparseLengthsSumNode *
2132 Function::createFusedRowwiseQuantizedSparseLengthsSum(
2133 llvm::StringRef name, Storage *data, NodeValue indices, NodeValue lengths,
2134 bool useFP16Accumulation, LengthsMode lengthsMode, float avgLength) {
2135 auto outTy =
2136 getOutputTypeOfFusedRowwiseQuantizedSLS(this, data, lengths.dims());
2137 return addNode(new FusedRowwiseQuantizedSparseLengthsSumNode(
2138 name, outTy, data, indices, lengths, useFP16Accumulation, lengthsMode,
2139 avgLength));
2140 }
2141
2142 /// Helper to get quantized data required for
2143 /// FusedRowwiseQuantizedSparseLengthsWeightedSumNode and
2144 /// FusedRowwiseQuantizedSparseLengthsSumNode. Function \p F uses float Tensor
2145 /// \p data to create a rowwise quantized Constant rwqData, which contains
2146 /// fused scales and offsets.
2147 static Constant *quantizeDataForFusedRowwiseQuantizedSparseLengthsWeightedSum(
2148 Function *F, Tensor &data, ElemKind precision) {
2149   // For fused rowwise quantization, we must have a two-dimensional input; if
2150   // the data Tensor is one-dimensional, an extra trailing dimension is added.
2151 const auto fDims = flattenCdr(data.dims());
2152 Tensor fData = data.getUnowned({fDims.first, fDims.second});
2153
2154 // Note: In rwqData, we are using a quantized type, however the scale/offset
2155 // are set to dummy values 0.0/0. This is because the actually used
2156 // scale/offset are fused inline with each row. Also, we expand the second
2157 // dimension to include space for the scale/offset, each 4 bytes
2158 // (float/int32_t).
2159 switch (precision) {
2160 case ElemKind::UInt8FusedQTy: {
2161 Constant *rwqData = F->getParent()->createConstant(
2162 precision, {fDims.first, fDims.second + 2 * (dim_t)sizeof(float)}, 0.0,
2163 0, "data");
2164 quantization::tensorFusedRowwiseQuantization<float>(
2165 fData, rwqData->getPayloadMutable());
2166 return rwqData;
2167 }
2168 case ElemKind::UInt8FusedFP16QTy: {
2169 Constant *rwqData = F->getParent()->createConstant(
2170 precision, {fDims.first, fDims.second + 2 * (dim_t)sizeof(float16_t)},
2171 0.0, 0, "data");
2172 quantization::tensorFusedRowwiseQuantization<float16_t>(
2173 fData, rwqData->getPayloadMutable());
2174 return rwqData;
2175 }
2176 case ElemKind::UInt4FusedFP16QTy: {
2177 // We pack 4-bit values into bytes, so given the input size in float we
2178 // divide by two and take the ceiling to make sure we have enough space for
2179 // all elements.
2180 const dim_t outerDim =
2181 std::ceil(((float)fDims.second) / 2) + 2 * sizeof(float16_t);
2182 Constant *rwqData = F->getParent()->createConstant(
2183 precision, {fDims.first, outerDim}, 0.0, 0, "data");
2184 quantization::tensorFusedRowwiseQuantization<float16_t>(
2185 fData, rwqData->getPayloadMutable());
2186 return rwqData;
2187 }
2188 default:
2189 llvm_unreachable("Invalid type for FusedRowwiswQuantization.");
2190 }
2191 }
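// Worked example (editorial, not part of the original source): a float Tensor
// of dims {R, 64} becomes a UInt8FusedQTy Constant of dims {R, 64 + 8} (one
// quantized byte per element plus a fused fp32 scale/offset pair), a
// UInt8FusedFP16QTy Constant of dims {R, 64 + 4}, or a UInt4FusedFP16QTy
// Constant of dims {R, ceil(64 / 2) + 4} = {R, 36} (two 4-bit elements packed
// per byte plus a fused fp16 scale/offset pair).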
2192
2193 FusedRowwiseQuantizedSparseLengthsWeightedSumNode *
2194 Function::createFusedRowwiseQuantizedSparseLengthsWeightedSum(
2195 llvm::StringRef name, Tensor &data, NodeValue weights, NodeValue indices,
2196 NodeValue lengths, ElemKind fusedElemKind, bool useFP16Accumulation,
2197 LengthsMode lengthsMode, float avgLength) {
2198 Constant *rwqData =
2199 quantizeDataForFusedRowwiseQuantizedSparseLengthsWeightedSum(
2200 this, data, fusedElemKind);
2201 return createFusedRowwiseQuantizedSparseLengthsWeightedSum(
2202 name, rwqData, weights, indices, lengths, useFP16Accumulation,
2203 lengthsMode, avgLength);
2204 }
2205
2206 FusedRowwiseQuantizedSparseLengthsSumNode *
2207 Function::createFusedRowwiseQuantizedSparseLengthsSum(
2208 llvm::StringRef name, Tensor &data, NodeValue indices, NodeValue lengths,
2209 ElemKind fusedElemKind, bool useFP16Accumulation, LengthsMode lengthsMode,
2210 float avgLength) {
2211 Constant *rwqData =
2212 quantizeDataForFusedRowwiseQuantizedSparseLengthsWeightedSum(
2213 this, data, fusedElemKind);
2214 return this->createFusedRowwiseQuantizedSparseLengthsSum(
2215 name, rwqData, indices, lengths, useFP16Accumulation, lengthsMode,
2216 avgLength);
2217 }
2218
2219 EmbeddingBagNode *
2220 Function::createEmbeddingBag(llvm::StringRef name, NodeValue data,
2221 NodeValue weights, NodeValue indices,
2222 NodeValue offsets, bool hasEndOffset,
2223 LengthsMode lengthsMode, float avgLength) {
2224 auto inDims = data.dims();
2225 ShapeVector outDims(inDims.begin(), inDims.end());
2226 outDims[0] = hasEndOffset ? offsets.dims()[0] - 1 : offsets.dims()[0];
2227 auto outTy = getParent()->uniqueTypeWithNewShape(data.getType(), outDims);
2228 return addNode(new EmbeddingBagNode(name, outTy, data, weights, indices,
2229 offsets, hasEndOffset, lengthsMode,
2230 avgLength));
2231 }
2232
2233 EmbeddingBagByteRowwiseOffsetsNode *
2234 Function::createEmbeddingBagByteRowwiseOffsets(
2235 llvm::StringRef name, Tensor &data, NodeValue weights, NodeValue indices,
2236 NodeValue offsets, ElemKind fusedElemKind, bool useFP16Accumulation,
2237 bool hasEndOffset, LengthsMode lengthsMode, float avgLength) {
2238 Constant *rwqData =
2239 quantizeDataForFusedRowwiseQuantizedSparseLengthsWeightedSum(
2240 this, data, fusedElemKind);
2241 return createEmbeddingBagByteRowwiseOffsets(
2242 name, rwqData, weights, indices, offsets, useFP16Accumulation,
2243 hasEndOffset, lengthsMode, avgLength);
2244 }
2245
2246 EmbeddingBagByteRowwiseOffsetsNode *
2247 Function::createEmbeddingBagByteRowwiseOffsets(
2248 llvm::StringRef name, NodeValue data, NodeValue weights, NodeValue indices,
2249 NodeValue offsets, bool useFP16Accumulation, bool hasEndOffset,
2250 LengthsMode lengthsMode, float avgLength) {
2251 std::vector<dim_t> segmentDims(offsets.dims().begin(), offsets.dims().end());
2252   // If hasEndOffset is true, the last offset just marks the end of the last
2253   // segment.
2254 if (hasEndOffset) {
2255 segmentDims[0] -= 1;
2256 }
2257 auto outTy = getOutputTypeOfFusedRowwiseQuantizedSLS(this, data, segmentDims);
2258 return addNode(new EmbeddingBagByteRowwiseOffsetsNode(
2259 name, outTy, data, weights, indices, offsets, useFP16Accumulation,
2260 hasEndOffset, lengthsMode, avgLength));
2261 }
2262
2263 LengthsToRangesNode *Function::createLengthsToRanges(llvm::StringRef name,
2264 NodeValue lengths) {
2265 ShapeVector outDims({lengths.dims()[0], 2});
2266 auto outTy = getParent()->uniqueTypeWithNewShape(lengths.getType(), outDims);
2267 return addNode(new LengthsToRangesNode(name, outTy, lengths));
2268 }
2269
2270 LengthsRangeFillNode *
2271 Function::createLengthsRangeFill(llvm::StringRef name, NodeValue lengths,
2272 unsigned_t maxOutputSize) {
2273 auto outTy =
2274 getParent()->uniqueTypeWithNewShape(lengths.getType(), {maxOutputSize});
2275 return addNode(new LengthsRangeFillNode(name, outTy, lengths));
2276 }
2277
2278 SparseToDenseNode *Function::createSparseToDense(llvm::StringRef name,
2279 NodeValue indices,
2280 NodeValue values,
2281 NodeValue dataToInferDim) {
2282 // The dimensions of the output are the same as the values tensor except for
2283 // the first dimension, which should match that of dataToInferDim.
2284 ShapeVector outDims(values.dims().begin(), values.dims().end());
2285 outDims[0] = dataToInferDim.dims()[0];
2286 auto outTy = getParent()->uniqueTypeWithNewShape(values.getType(), outDims);
2287 return addNode(new SparseToDenseNode(name, outTy, indices, values));
2288 }
2289
2290 SparseToDenseMaskNode *Function::createSparseToDenseMask(
2291 llvm::StringRef name, NodeValue indices, NodeValue values,
2292 NodeValue defaultValue, NodeValue lengths, llvm::ArrayRef<dim_t> mask) {
2293 auto lengthsDims = lengths.dims();
2294 auto valueDims = defaultValue.dims();
2295 ShapeVector outDims = {(dim_t)mask.size()};
2296   // If lengths is a 0-dimensional tensor, then there is no batch dimension.
2297 if (lengthsDims.size() > 0) {
2298 outDims.insert(outDims.begin(), lengthsDims[0]);
2299 }
2300 outDims.insert(outDims.end(), valueDims.begin(), valueDims.end());
2301 auto outTy = getParent()->uniqueTypeWithNewShape(values.getType(), outDims);
2302 return addNode(new SparseToDenseMaskNode(name, outTy, indices, values,
2303 defaultValue, lengths, mask));
2304 }
2305
2306 SaveNode *Function::createSave(llvm::StringRef name, NodeValue input) {
2307 auto *dest = getParent()->createPlaceholder(input.getType(), name, false);
2308 return createSave(name, input, dest);
2309 }
2310
2311 SaveNode *Function::createSave(llvm::StringRef name, NodeValue input,
2312 Placeholder *output, bool skipSuffix) {
2313 return addNode(new SaveNode(skipSuffix ? name.str() : (name + "_save").str(),
2314 input, output));
2315 }
2316
2317 QuantizationProfileNode *
2318 Function::createQuantizationProfile(PlaceholderBindings &bindings,
2319 llvm::StringRef name, NodeValue input,
2320 dim_t numHistogramBins) {
2321 auto *histogram = getParent()->createPlaceholder(
2322 ElemKind::FloatTy, {numHistogramBins}, "histogram_" + name.str(), false);
2323 bindings.allocate(histogram)->zero();
2324   // Intermediate data used for histogram calculations.
2325   // The minimum tensor value seen so far is kept in the first position.
2326   // The maximum tensor value seen so far is kept in the second position.
2327 auto *computationInfoPH = getParent()->createPlaceholder(
2328 ElemKind::FloatTy, {2}, "CI_" + name.str(), false);
2329 bindings.allocate(computationInfoPH);
2330 auto *computationInfoTensor = bindings.get(computationInfoPH);
2331 auto handle = computationInfoTensor->getHandle<float>();
2332 handle.raw(0) = std::numeric_limits<float>::max();
2333 handle.raw(1) = std::numeric_limits<float>::lowest();
2334
2335 return addNode(new QuantizationProfileNode(
2336 "QI_" + name.str(), input, histogram, computationInfoPH,
2337 input.getNode()->getName().str(), input.getResNo()));
2338 }
2339
2340 IntLookupTableNode *
2341 Function::createIntLookupTable(llvm::StringRef name, NodeValue input,
2342 llvm::ArrayRef<int8_t> initValues,
2343 TypeRef outTy) {
2344 auto *mapping = getParent()->createConstant(
2345 ElemKind::Int8QTy, {(dim_t)initValues.size()}, outTy->getScale(),
2346 outTy->getOffset(), "mapping");
2347 mapping->getHandle<int8_t>() = initValues;
2348
2349 return addNode(new IntLookupTableNode(name, outTy, input, mapping));
2350 }
2351
2352 IntLookupTableNode *Function::createIntTanh(llvm::StringRef name,
2353 NodeValue input, TypeRef outTy) {
2354 static int8_t mapping[] = {
2355 -128, -127, -126, -126, -126, -126, -126, -126, -126, -126, -126, -126,
2356 -126, -126, -126, -126, -126, -126, -126, -126, -125, -125, -125, -125,
2357 -125, -125, -125, -125, -125, -125, -125, -124, -124, -124, -124, -124,
2358 -124, -124, -123, -123, -123, -123, -123, -123, -122, -122, -122, -122,
2359 -121, -121, -121, -120, -120, -120, -120, -119, -119, -118, -118, -118,
2360 -117, -117, -116, -116, -115, -115, -114, -114, -113, -112, -112, -111,
2361 -110, -109, -109, -108, -107, -106, -105, -104, -103, -102, -101, -100,
2362 -99, -98, -96, -95, -94, -92, -91, -89, -88, -86, -85, -83,
2363 -81, -79, -77, -76, -74, -72, -69, -67, -65, -63, -61, -58,
2364 -56, -53, -51, -48, -46, -43, -41, -38, -35, -32, -29, -27,
2365 -24, -21, -18, -15, -12, -9, -6, -3, 0, 3, 6, 9,
2366 12, 15, 18, 21, 24, 27, 29, 32, 35, 38, 41, 43,
2367 46, 48, 51, 53, 56, 58, 61, 63, 65, 67, 69, 72,
2368 74, 76, 77, 79, 81, 83, 85, 86, 88, 89, 91, 92,
2369 94, 95, 96, 98, 99, 100, 101, 102, 103, 104, 105, 106,
2370 107, 108, 109, 109, 110, 111, 112, 112, 113, 114, 114, 115,
2371 115, 116, 116, 117, 117, 118, 118, 118, 119, 119, 120, 120,
2372 120, 120, 121, 121, 121, 122, 122, 122, 122, 123, 123, 123,
2373 123, 123, 123, 124, 124, 124, 124, 124, 124, 124, 125, 125,
2374 125, 125, 125, 125, 125, 125, 125, 125, 125, 126, 126, 126,
2375 126, 126, 126, 126, 126, 126, 126, 126, 126, 126, 126, 126,
2376 126, 126, 126, 127};
2377
2378 return createIntLookupTable(name, input, mapping, outTy);
2379 }
2380
2381 IntLookupTableNode *Function::createIntSigmoid(llvm::StringRef name,
2382 NodeValue input, TypeRef outTy) {
2383 static int8_t mapping[] = {
2384 -128, -127, -127, -127, -127, -127, -127, -127, -127, -127, -127, -127,
2385 -127, -127, -127, -127, -127, -127, -127, -126, -126, -126, -126, -126,
2386 -126, -126, -126, -126, -126, -126, -125, -125, -125, -125, -125, -125,
2387 -125, -125, -124, -124, -124, -124, -124, -123, -123, -123, -123, -122,
2388 -122, -122, -122, -121, -121, -121, -120, -120, -120, -119, -119, -118,
2389 -118, -118, -117, -117, -116, -115, -115, -114, -114, -113, -112, -112,
2390 -111, -110, -109, -109, -108, -107, -106, -105, -104, -103, -102, -101,
2391 -99, -98, -97, -96, -94, -93, -91, -90, -88, -87, -85, -83,
2392 -82, -80, -78, -76, -74, -72, -70, -68, -66, -63, -61, -59,
2393 -56, -54, -51, -49, -46, -44, -41, -38, -36, -33, -30, -27,
2394 -24, -21, -18, -15, -12, -9, -6, -3, -1, 2, 5, 8,
2395 11, 14, 17, 20, 23, 26, 29, 32, 35, 37, 40, 43,
2396 45, 48, 50, 53, 55, 58, 60, 62, 65, 67, 69, 71,
2397 73, 75, 77, 79, 81, 82, 84, 86, 87, 89, 90, 92,
2398 93, 95, 96, 97, 98, 100, 101, 102, 103, 104, 105, 106,
2399 107, 108, 108, 109, 110, 111, 111, 112, 113, 113, 114, 114,
2400 115, 116, 116, 117, 117, 117, 118, 118, 119, 119, 119, 120,
2401 120, 120, 121, 121, 121, 121, 122, 122, 122, 122, 123, 123,
2402 123, 123, 123, 124, 124, 124, 124, 124, 124, 124, 124, 125,
2403 125, 125, 125, 125, 125, 125, 125, 125, 125, 125, 126, 126,
2404 126, 126, 126, 126, 126, 126, 126, 126, 126, 126, 126, 126,
2405 126, 126, 126, 127};
2406
2407 return createIntLookupTable(name, input, mapping, outTy);
2408 }
2409
2410 TopKNode *Function::createTopK(llvm::StringRef name, NodeValue input,
2411 unsigned_t k, ElemKind outIndicesTyKind) {
2412 auto inDims = input.dims();
2413 assert(inDims.size() > 0);
2414 assert(k <= inDims.back());
2415 ShapeVector outDims(inDims.begin(), inDims.end());
2416 outDims.back() = k;
2417 auto OT = getParent()->uniqueTypeWithNewShape(input.getType(), outDims);
2418 return addNode(new TopKNode(
2419 name, OT, getParent()->uniqueType(outIndicesTyKind, outDims), input, k));
2420 }
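// Example (editorial sketch, not part of the original source): assuming a
// Function *F and a NodeValue `in` with dims {2, 10}, taking the top k = 3
// values along the last dimension yields a Values result of dims {2, 3} and
// an Indices result of dims {2, 3} (Int64ITy by default):
//
//   TopKNode *tk = F->createTopK("topk", in, /* k */ 3);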
2421
2422 TopKNode *Function::createTopK(llvm::StringRef name, NodeValue input,
2423 unsigned_t k) {
2424 return createTopK(name, input, k, ElemKind::Int64ITy);
2425 }
2426
2427 ArgMaxNode *Function::createArgMax(llvm::StringRef name, NodeValue input,
2428 unsigned_t axis, bool keepDims,
2429 ElemKind elemTy) {
2430 ShapeVector outDims = reduceDims(input.dims(), {axis}, keepDims);
2431 auto OT = getParent()->uniqueType(elemTy, outDims);
2432 return addNode(new ArgMaxNode(name, OT, input, axis, keepDims));
2433 }
2434
2435 ArgMinNode *Function::createArgMin(llvm::StringRef name, NodeValue input,
2436 unsigned_t axis, bool keepDims,
2437 ElemKind elemTy) {
2438 ShapeVector outDims = reduceDims(input.dims(), {axis}, keepDims);
2439 auto OT = getParent()->uniqueType(elemTy, outDims);
2440 return addNode(new ArgMinNode(name, OT, input, axis, keepDims));
2441 }
2442
2443 GatherNode *Function::createGather(llvm::StringRef name, NodeValue data,
2444 NodeValue indices, unsigned_t batchDims) {
2445 auto dDims = data.dims();
2446 auto iDims = indices.dims();
2447 assert(dDims.size() > batchDims);
2448 ShapeVector outDims;
2449 outDims.insert(outDims.end(), dDims.begin(), dDims.begin() + batchDims);
2450 outDims.insert(outDims.end(), iDims.begin(), iDims.end());
2451 outDims.insert(outDims.end(), dDims.begin() + batchDims + 1, dDims.end());
2452 return addNode(new GatherNode(
2453 name, getParent()->uniqueTypeWithNewShape(data.getType(), outDims), data,
2454 indices, batchDims));
2455 }
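// Example (editorial sketch, not part of the original source): assuming a
// Function *F, a NodeValue `data` with dims {4, 5} and `indices` with dims
// {2, 3}, gathering with batchDims = 0 picks whole rows of `data`, so the
// result has dims {2, 3, 5} with result[i][j] == data[indices[i][j]]:
//
//   GatherNode *gn = F->createGather("gather", data, indices,
//                                    /* batchDims */ 0);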
2456
2457 GatherRangesNode *Function::createGatherRanges(llvm::StringRef name,
2458 NodeValue data, NodeValue ranges,
2459 unsigned_t maxOutputSize) {
2460 auto numRanges = ranges.dims()[0];
2461 return addNode(new GatherRangesNode(
2462 name,
2463 /*OutputTy=*/
2464 getParent()->uniqueTypeWithNewShape(data.getType(), {maxOutputSize}),
2465 /*LengthsTy=*/
2466 getParent()->uniqueTypeWithNewShape(ranges.getType(), numRanges), data,
2467 ranges));
2468 }
2469
2470 ScatterDataNode *Function::createScatterData(llvm::StringRef name,
2471 NodeValue data, NodeValue indices,
2472 NodeValue slices,
2473 bool cumulative) {
2474 return addNode(new ScatterDataNode(name, data, indices, slices, cumulative));
2475 }
2476
createBatchOneHot(llvm::StringRef name,NodeValue data,NodeValue lengths,NodeValue values)2477 BatchOneHotNode *Function::createBatchOneHot(llvm::StringRef name,
2478 NodeValue data, NodeValue lengths,
2479 NodeValue values) {
2480 auto outTy = getParent()->uniqueTypeWithNewShape(
2481 data.getType(), {data.dims()[0], values.dims()[0]});
2482 return addNode(new BatchOneHotNode(name, outTy, data, lengths, values));
2483 }
2484
createSpaceToDepth(llvm::StringRef name,NodeValue input,unsigned blockSize)2485 SpaceToDepthNode *Function::createSpaceToDepth(llvm::StringRef name,
2486 NodeValue input,
2487 unsigned blockSize) {
2488 assert(blockSize > 0 && "BlockSize must be >= 1.");
2489
2490 auto inputDim = input.dims();
2491 assert(inputDim.size() == 4 && "Dimension size of 4 is expected.");
2492 assert((inputDim[1] % blockSize == 0 && inputDim[2] % blockSize == 0) &&
2493 "Height and Width needs to be multiple of blockSize.");
2494 std::vector<dim_t> newDim = {inputDim[0], inputDim[1] / blockSize,
2495 inputDim[2] / blockSize,
2496 inputDim[3] * blockSize * blockSize};
2497 auto outTy = getParent()->uniqueTypeWithNewShape(input.getType(), newDim);
2498 return addNode(new SpaceToDepthNode(name, outTy, input, blockSize));
2499 }
2500
createUpsample(llvm::StringRef name,NodeValue input,dim_t numLeadingDims)2501 ReshapeNode *Function::createUpsample(llvm::StringRef name, NodeValue input,
2502 dim_t numLeadingDims) {
2503 auto dims = input.dims();
2504 DCHECK_LE(numLeadingDims, dims.size())
2505 << "numLeadingDims " << numLeadingDims
2506 << " must be less than total num dims " << dims.size();
2507 dim_t dim0Dims = 1;
2508 dim_t dim1Dims = 1;
2509 for (dim_t d = 0; d < dims.size(); d++) {
2510 dim0Dims *= dims[d];
2511 }
2512
2513 Node *cur = input;
2514 for (dim_t d = 0; d < numLeadingDims; d++) {
2515 auto *reshaped =
2516 createReshape(name.str() + "_dim_reshape", cur, {dim0Dims, dim1Dims});
2517 cur =
2518 createTile(name.str() + "_tile", reshaped, /* tiles */ 2, /* axis */ 1);
2519 dim_t sz = dims[dims.size() - d - 1];
2520 dim0Dims /= sz;
2521 dim1Dims *= 2 * sz;
2522 }
2523 std::vector<dim_t> outDims(dims.begin(), dims.end());
2524 for (dim_t d = dims.size() - numLeadingDims; d < dims.size(); d++) {
2525 outDims[d] *= 2;
2526 }
2527 return createReshape(name.str() + "_last_reshape", cur, outDims);
2528 }
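
// How the upsample is built (illustrative trace): for an input of shape
// {N, H} with numLeadingDims = 1, the loop reshapes to {N * H, 1}, tiles
// along axis 1 to {N * H, 2} (duplicating every element in place), and the
// final reshape yields {N, 2 * H}. Repeating this per axis produces a 2x
// nearest-neighbor upsample of the trailing numLeadingDims dimensions.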

ResizeNearestNode *Function::createResizeNearest(llvm::StringRef name,
                                                 NodeValue input,
                                                 llvm::ArrayRef<float> scale) {
  auto inputDim = input.dims();
  DCHECK_EQ(inputDim.size(), scale.size())
      << "Input dimension size: " << inputDim.size()
      << " and scale size: " << scale.size() << " must be the same.";

  std::vector<dim_t> newDim;

  for (size_t i = 0; i < scale.size(); i++) {
    auto newD = dim_t(std::floor(inputDim[i] * scale[i]));
    DCHECK_GT(newD, 0) << "Scaled dim is " << newD
                       << "; the scaled value must be larger than 0.";
    newDim.push_back(newD);
  }

  auto outTy = getParent()->uniqueTypeWithNewShape(input.getType(), newDim);
  return addNode(new ResizeNearestNode(name, outTy, input, scale));
}

ResizeNearestNode *Function::createResizeNearest(llvm::StringRef name,
                                                 NodeValue input,
                                                 TypeRef outTy) {
  auto inputDim = input.dims();
  auto outputDim = outTy->dims();
  DCHECK_EQ(inputDim.size(), outputDim.size())
      << "Input dimension size: " << inputDim.size()
      << " and output dimension size: " << outputDim.size()
      << " must be the same.";

  std::vector<float> scales;
  for (size_t i = 0; i < inputDim.size(); i++) {
    float scale = (outputDim[i] / (float)inputDim[i]);
    DCHECK_GT(scale, 0.0) << "Scale: " << scale
                          << "; a scale larger than 0 is expected.";
    scales.push_back(scale);
  }

  return addNode(new ResizeNearestNode(name, outTy, input, scales));
}

ResizeBilinearNode *
Function::createResizeBilinear(llvm::StringRef name, NodeValue input,
                               llvm::ArrayRef<float> scale) {
  auto inputDim = input.dims();
  DCHECK_EQ(inputDim.size(), scale.size())
      << "Input dimension size: " << inputDim.size()
      << " and scale size: " << scale.size() << " must be the same.";

  std::vector<dim_t> newDim;

  for (size_t i = 0; i < scale.size(); i++) {
    auto newD = dim_t(std::floor(inputDim[i] * scale[i]));
    DCHECK_GT(newD, 0) << "Scaled dim is " << newD
                       << "; the scaled value must be larger than 0.";
    newDim.push_back(newD);
  }

  auto outTy = getParent()->uniqueTypeWithNewShape(input.getType(), newDim);
  return addNode(new ResizeBilinearNode(name, outTy, input, scale));
}

ResizeBilinearNode *Function::createResizeBilinear(llvm::StringRef name,
                                                   NodeValue input,
                                                   TypeRef outTy) {
  auto inputDim = input.dims();
  auto outputDim = outTy->dims();
  DCHECK_EQ(inputDim.size(), outputDim.size())
      << "Input dimension size: " << inputDim.size()
      << " and output dimension size: " << outputDim.size()
      << " must be the same.";

  std::vector<float> scales;
  for (size_t i = 0; i < inputDim.size(); i++) {
    float scale = (outputDim[i] / (float)inputDim[i]);
    DCHECK_GT(scale, 0.0) << "Scale: " << scale
                          << "; a scale larger than 0 is expected.";
    scales.push_back(scale);
  }

  return addNode(new ResizeBilinearNode(name, outTy, input, scales));
}

QuantizeNode *Function::createQuantize(llvm::StringRef name, NodeValue input,
                                       TypeRef outTy) {
  assert(input.getType()->isFPType() && "Input must be a floating type");
  assert(outTy->isQuantizedType() && "Output must be a quantized type");
  assert(input.dims().equals(outTy->dims()) &&
         "Different dimensions for input and output");

  return addNode(
      new QuantizeNode(name, getParent()->uniqueType(*outTy), input));
}

DequantizeNode *Function::createDequantize(llvm::StringRef name,
                                           NodeValue input, ElemKind k) {
  assert(input.getType()->isQuantizedType() &&
         "Input must be a quantized type");
  assert(isFloatElemKind(k) && "Result must be float type.");
  ShapeVector outShape(input.dims().begin(), input.dims().end());
  if (input.getElementType() == ElemKind::UInt8FusedQTy) {
    assert(outShape.size() == 2 && "Fused tensors should be 2D");
    assert(outShape[1] > 2 * sizeof(float) &&
           "Expected space for per-row scale/offset");
    outShape[1] -= 2 * sizeof(float);
  }
  TypeRef outTy = getParent()->uniqueType(Type(k, outShape));
  return createDequantize(name, input, outTy);
}
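
// Layout note (illustrative): a UInt8FusedQTy tensor appends one float scale
// and one float offset (8 bytes total) to every row, so a fused input of
// shape {N, 72} holds 64 data bytes per row and dequantizes to a FloatTy
// tensor of shape {N, 64}.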

DequantizeNode *Function::createDequantize(llvm::StringRef name,
                                           NodeValue input, TypeRef outTy) {
  assert(input.getType()->isQuantizedType() &&
         "Input must be a quantized type");
  assert(outTy->isFPType() && "Output should be an FP type");
  return addNode(new DequantizeNode(name, outTy, input));
}

RescaleQuantizedNode *Function::createRescaleQuantized(llvm::StringRef name,
                                                       NodeValue input,
                                                       TypeRef outTy) {
  assert(input.getType()->isQuantizedType() &&
         "Input must be a quantized type");
  assert(outTy->isQuantizedType() && "Output must be a quantized type");
  assert(input.dims().equals(outTy->dims()) &&
         "Different dimensions for input and output");

  return addNode(
      new RescaleQuantizedNode(name, getParent()->uniqueType(*outTy), input));
}

Node *Function::createWeightedSum(llvm::StringRef name,
                                  llvm::ArrayRef<NodeValue> data,
                                  llvm::ArrayRef<NodeValue> weights) {
  assert(data.size() == weights.size() &&
         "Must have same number of data and weights.");
  assert(data.size() > 0 && "No inputs provided.");

  const auto *outTy = data[0].getType();

  // Create a zero splat to bootstrap the adding chain.
  Node *currAdd = createSplat(name.str() + ".splat", outTy, 0.);

  for (size_t i = 0, e = data.size(); i < e; i++) {
    assert(weights[i].getType()->size() == 1 &&
           "Each provided weight node must be size 1.");
    assert(outTy == data[i].getType() &&
           "All data nodes must have the same type.");

    // Broadcast the current weight to the same shape as the data.
    auto *bcastW =
        createBroadcast(name.str() + ".bcastWeight" + std::to_string(i),
                        weights[i], outTy->dims(), /* axis */ 0);

    // Element-wise multiply the broadcasted weight by the data.
    auto *scaledD =
        createMul(name.str() + ".mul" + std::to_string(i), bcastW, data[i]);

    // Element-wise add the scaled data to the running total.
    currAdd =
        createAdd(name.str() + ".add" + std::to_string(i), scaledD, currAdd);
  }

  // Return the final weighted sum via the last add in the chain.
  return currAdd;
}
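
// Usage sketch (hypothetical names): each weight is a size-1 node that is
// broadcast to the data shape before the multiply-accumulate chain:
//
//   Node *ws = F->createWeightedSum("ws", {A, B}, {wA, wB});
//   // ws == wA * A + wB * B, where A and B share one type.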

Node *Function::createBatchBoxCox(llvm::StringRef name, NodeValue data,
                                  NodeValue lambda1, NodeValue lambda2,
                                  float epsilon) {
  assert((lambda1.dims() == lambda2.dims()) &&
         "lambda1 and lambda2 must have the same shape");
  assert((lambda1.getType()->getElementType() == lambda2.getElementType()) &&
         "lambda1 and lambda2 must have the same element type");
  assert((lambda1.getType()->getElementType() == data.getElementType()) &&
         "data and lambdas must have the same element type");
  assert((lambda1.dims().size() == 1) && "lambda1 and lambda2 must be vectors");
  assert((data.dims().size() == 2) && "data must be a matrix");
  assert((data.dims()[1] == lambda1.dims()[0]) &&
         "there must be one lambda1/lambda2 entry per column of data");

  return addNode(new BatchBoxCoxNode(name, data, lambda1, lambda2, epsilon));
}
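
// Per-element transform (following the usual Box-Cox definition; epsilon
// guards the small-lambda1 branch):
//
//   y = (lambda1 != 0) ? (pow(x + lambda2, lambda1) - 1) / lambda1
//                      : log(x + lambda2)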

ClipNode *Function::createClip(llvm::StringRef name, NodeValue input,
                               TypeRef outTy, float min, float max) {
  return addNode(new ClipNode(name, outTy, input, min, max));
}

ClipNode *Function::createClip(llvm::StringRef name, NodeValue input, float min,
                               float max) {
  return addNode(new ClipNode(name, input.getType(), input, min, max));
}

ClipNode *Function::createClipMinMaxFP16(llvm::StringRef name,
                                         NodeValue input) {
  constexpr float float16Min = -65504.0f;
  constexpr float float16Max = 65504.0f;
  return createClip(name, input, float16Min, float16Max);
}

ClipNode *Function::createClipMinMaxBFloat16(llvm::StringRef name,
                                             NodeValue input) {
  // Clip to the float range bfloat16 can represent. Note that the lower
  // bound must be -FLT_MAX; FLT_MIN would be wrong here, since it is the
  // smallest positive normal float rather than the most negative one.
  constexpr float bfloat16Min = -FLT_MAX;
  constexpr float bfloat16Max = FLT_MAX;
  return createClip(name, input, bfloat16Min, bfloat16Max);
}

//===----------------------------------------------------------------------===//
//                        Placeholder-builder methods.
//===----------------------------------------------------------------------===//

BatchNormalizationNode *Function::createBatchNormalization(
    PlaceholderBindings &bindings, llvm::StringRef name, NodeValue input,
    unsigned_t channelIdx, float epsilon, float momentum) {
  // Figure out how many channels are in the tensor.
  dim_t channels = input.dims()[channelIdx];

  ElemKind inputTy = input.getType()->getElementType();

  // Allocate the learnable parameters beta and gamma.
  auto *beta =
      getParent()->createPlaceholder(inputTy, {channels}, "beta", true);
  bindings.allocate(beta)->init(Tensor::InitKind::Broadcast, 0.1, getPRNG());

  auto *scale =
      getParent()->createPlaceholder(inputTy, {channels}, "scale", true);
  bindings.allocate(scale)->init(Tensor::InitKind::Broadcast, 0.001, getPRNG());

  auto *mean =
      getParent()->createPlaceholder(inputTy, {channels}, "mean", false);
  bindings.allocate(mean)->zero();

  auto *variance =
      getParent()->createPlaceholder(inputTy, {channels}, "variance", false);
  bindings.allocate(variance)->init(Tensor::InitKind::Broadcast, 1.0,
                                    getPRNG());

  return createBatchNormalization(name, input, beta, scale, mean, variance,
                                  channelIdx, epsilon, momentum);
}

ConvolutionNode *Function::createConv(
    PlaceholderBindings &bindings, llvm::StringRef name, NodeValue input,
    dim_t outChannels, llvm::ArrayRef<unsigned_t> kernels,
    llvm::ArrayRef<unsigned_t> strides, llvm::ArrayRef<unsigned_t> pads,
    unsigned_t group, unsigned_t dilation, ConvolutionLayout layout) {
  ShapeNHWC idim = ShapeNHWC(input.dims());
  ShapeHW kdim(kernels);
  PaddingTLBR pdim(pads);
  (void)pdim;
  assert((idim.w + pdim.left + pdim.right) >= kdim.width &&
         (idim.h + pdim.top + pdim.bottom) >= kdim.height &&
         "buffer too small for selected stride");

  assert(group > 0 && "group should be larger than 0");
  assert(idim.c % group == 0 && "channels number must be divisible by groups");
  assert(outChannels % group == 0 && "outChannels must be divisible by groups");

  // Calculate the size and allocate the output buffer.
  auto outSz = calculateConvPoolOutputDims(idim.h, idim.w, kernels, strides,
                                           pads, dilation);

  std::array<dim_t, 4> outDims = {
      {idim.n, outSz.first, outSz.second, outChannels}};

  // Allocate the Filter and Bias tensors.
  std::array<dim_t, 4> filterDim = {
      {outChannels, kdim.height, kdim.width, idim.c / group}};
  size_t fanIn = kdim.height * kdim.width * idim.c;
  ElemKind inputTy = input.getType()->getElementType();
  assert(isFloatElemKind(inputTy) && "Convolution on non-floating point type?");
  auto *filter =
      getParent()->createPlaceholder(inputTy, filterDim, "filter", true);
  bindings.allocate(filter)->init(glow::Tensor::InitKind::Xavier, fanIn,
                                  getPRNG());

  auto *bias =
      getParent()->createPlaceholder(inputTy, {outChannels}, "bias", true);
  bindings.allocate(bias)->init(glow::Tensor::InitKind::Broadcast, 0.1,
                                getPRNG());

  auto OT = getParent()->uniqueType(inputTy, outDims);

  return addNode(new ConvolutionNode(name, OT, input, filter, bias, kernels,
                                     strides, pads, group, dilation, layout,
                                     FusedActivation::NONE));
}

ConvolutionNode *Function::createConv(PlaceholderBindings &bindings,
                                      llvm::StringRef name, NodeValue input,
                                      dim_t outChannels, unsigned_t kernel,
                                      unsigned_t stride, unsigned_t pad,
                                      unsigned_t group, unsigned_t dilation,
                                      ConvolutionLayout layout) {
  llvm::SmallVector<unsigned_t, 4> pads = {pad, pad, pad, pad};
  llvm::SmallVector<unsigned_t, 2> strides = {stride, stride};
  llvm::SmallVector<unsigned_t, 2> kernels = {kernel, kernel};
  return createConv(bindings, name, input, outChannels, kernels, strides, pads,
                    group, dilation, layout);
}
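
// Output-shape example (illustrative, per calculateConvPoolOutputDims): each
// spatial dimension becomes
//
//   out = (in + padNear + padFar - (dilation * (kernel - 1) + 1)) / stride + 1
//
// so a 28x28 NHWC input with kernel 5, stride 1, pad 2, and dilation 1 keeps
// its 28x28 spatial extent, giving an output of {N, 28, 28, outChannels}.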

Convolution3DNode *Function::createConv3D(PlaceholderBindings &bindings,
                                          llvm::StringRef name, NodeValue input,
                                          dim_t outChannels,
                                          llvm::ArrayRef<unsigned_t> kernels,
                                          llvm::ArrayRef<unsigned_t> strides,
                                          llvm::ArrayRef<unsigned_t> pads,
                                          unsigned_t group) {
  ShapeNTHWC idim(input.dims());
  ShapeTHW kdim(kernels);

  assert(group > 0 && "group should be larger than 0");
  assert(idim.c % group == 0 && "channels number must be divisible by groups");
  assert(outChannels % group == 0 && "outChannels must be divisible by groups");

  // Calculate the size and allocate the output buffer.
  auto outSz = calculate3DConvPoolOutputDims(idim.t, idim.h, idim.w, kernels,
                                             strides, pads);

  std::array<dim_t, 5> outDims = {
      {idim.n, outSz.temporal_frames, outSz.height, outSz.width, outChannels}};

  // Allocate the Filter and Bias tensors.
  std::array<dim_t, 5> filterDim = {{outChannels, kdim.temporal_frames,
                                     kdim.height, kdim.width, idim.c / group}};

  dim_t fanIn = kdim.temporal_frames * kdim.height * kdim.width * idim.c;
  ElemKind inputTy = input.getType()->getElementType();
  assert(isFloatElemKind(inputTy) &&
         "Convolution3D on non-floating point type?");
  auto *filter =
      getParent()->createPlaceholder(inputTy, filterDim, "filter", true);
  bindings.allocate(filter)->init(glow::Tensor::InitKind::Xavier, fanIn,
                                  getPRNG());

  auto *bias =
      getParent()->createPlaceholder(inputTy, {outChannels}, "bias", true);
  bindings.allocate(bias)->init(glow::Tensor::InitKind::Broadcast, 0.1,
                                getPRNG());

  auto OT = getParent()->uniqueType(inputTy, outDims);

  assertConv3DDims(input, filter, bias, kernels, strides, pads, group);

  return addNode(new Convolution3DNode(name, OT, input, filter, bias, kernels,
                                       strides, pads, group));
}

Convolution3DNode *Function::createConv3D(PlaceholderBindings &bindings,
                                          llvm::StringRef name, NodeValue input,
                                          size_t outChannels, unsigned_t kernel,
                                          unsigned_t stride, unsigned_t pad,
                                          unsigned_t group) {
  llvm::SmallVector<unsigned_t, 6> pads = {pad, pad, pad, pad, pad, pad};
  llvm::SmallVector<unsigned_t, 3> strides = {stride, stride, stride};
  llvm::SmallVector<unsigned_t, 3> kernels = {kernel, kernel, kernel};
  return createConv3D(bindings, name, input, outChannels, kernels, strides,
                      pads, group);
}

ChannelwiseQuantizedConvolutionNode *Function::createChannelwiseQuantizedConv(
    llvm::StringRef name, NodeValue input, NodeValue filter, NodeValue bias,
    NodeValue filterScales, NodeValue filterOffsets, NodeValue biasScales,
    NodeValue biasOffsets, TypeRef outTy, llvm::ArrayRef<unsigned_t> kernels,
    llvm::ArrayRef<unsigned_t> strides, llvm::ArrayRef<unsigned_t> pads,
    unsigned_t group, unsigned_t dilation, bool quantizeFilter,
    bool quantizeBias, quantization::Schema schema, ElemKind filterElemQTy,
    ElemKind biasElemQTy) {

  // Validate dimensions.
  bool isConv3D = (input.getType()->dims().size() == 5);
  if (isConv3D) {
    assertConv3DDims(input, filter, bias, kernels, strides, pads, group);
  } else {
    assertConvDims(input, filter, bias, kernels, strides, pads, group);
  }

  // Validate bias precision.
  auto biasElemKind = bias.getElementType();
  DCHECK(biasElemKind == ElemKind::Int8QTy ||
         biasElemKind == ElemKind::Int32QTy ||
         biasElemKind == ElemKind::FloatTy)
      << "Unsupported element type for ChannelwiseQuantizedConvolution bias: "
      << Type::getElementName(biasElemKind).str();

  // Validate filter precision.
  auto filterElemKind = filter.getElementType();
  DCHECK(filterElemKind == ElemKind::Int8QTy ||
         filterElemKind == ElemKind::FloatTy)
      << "Unsupported element type for ChannelwiseQuantizedConvolution "
      << "filter: " << Type::getElementName(filterElemKind).str();

  DCHECK(dyn_cast<Constant>(bias.getNode()))
      << "Bias input to ChannelwiseQuantizedConvolutionNode must be a Constant";

  DCHECK(dyn_cast<Constant>(filter.getNode()))
      << "Filter input to ChannelwiseQuantizedConvolutionNode must be a "
         "Constant";

  DCHECK(!filterScales.getNode() || dyn_cast<Constant>(filterScales.getNode()))
      << "Filter scales input to ChannelwiseQuantizedConvolutionNode must be "
         "null or Constant";

  DCHECK(!filterOffsets.getNode() ||
         dyn_cast<Constant>(filterOffsets.getNode()))
      << "Filter offsets input to ChannelwiseQuantizedConvolutionNode must be "
         "null or Constant";

  DCHECK(!biasScales.getNode() || dyn_cast<Constant>(biasScales.getNode()))
      << "Bias scales input to ChannelwiseQuantizedConvolutionNode must be "
         "null or Constant";

  DCHECK(!biasOffsets.getNode() || dyn_cast<Constant>(biasOffsets.getNode()))
      << "Bias offsets input to ChannelwiseQuantizedConvolutionNode must be "
         "null or Constant";

  // Number of output channels.
  dim_t numChannels = outTy->dims().back();
  dim_t qDim = 0;
  dim_t qStep = 1;

  // If the input filter is FLOAT and filterScales/filterOffsets are NOT
  // provided, then compute them automatically for the given schema and
  // filterElemQTy. If the input filter is QUANTIZED, then
  // filterScales/filterOffsets are mandatory.
  if (!filterScales.getNode() || !filterOffsets.getNode()) {
    DCHECK(filterElemKind == ElemKind::FloatTy)
        << "ChannelwiseQuantizedConvolution: If the input filter is "
        << "quantized then the filter scales/offsets must be provided!";
    Constant *filterC = dyn_cast<Constant>(filter.getNode());
    Constant *filterScalesC = getParent()->createConstant(
        ElemKind::FloatTy, {numChannels}, "filterScales");
    Constant *filterOffsetsC = getParent()->createConstant(
        ElemKind::Int32ITy, {numChannels}, "filterOffsets");
    // Get the filter channelwise TensorQuantizationParams.
    quantization::getTensorQuantizationParams(
        filterC->getPayload(), filterScalesC->getPayloadMutable(),
        filterOffsetsC->getPayloadMutable(), schema, filterElemQTy, qDim,
        qStep);
    filterScales = NodeValue(filterScalesC);
    filterOffsets = NodeValue(filterOffsetsC);
  }

  // If the input filter is FLOAT, quantize it channelwise to filterElemQTy.
  if (quantizeFilter && filterElemKind == ElemKind::FloatTy) {
    Constant *filterC = dyn_cast<Constant>(filter.getNode());
    Constant *filterCQ = getParent()->createConstant(
        filterElemQTy, filterC->getType()->dims(), 1.0, 0, "filter");
    Constant *filterScalesC = dyn_cast<Constant>(filterScales.getNode());
    Constant *filterOffsetsC = dyn_cast<Constant>(filterOffsets.getNode());
    // Quantize the filter channelwise.
    filterCQ->getPayloadMutable() = quantization::quantizeTensor(
        filterC->getPayload(), filterScalesC->getPayload(),
        filterOffsetsC->getPayload(), filterElemQTy, qDim, qStep);
    filter = NodeValue(filterCQ);
  }

  // If the input bias is FLOAT and biasScales/biasOffsets are NOT provided,
  // then compute them automatically for the given schema and biasElemQTy.
  // If the input bias is QUANTIZED and biasScales/biasOffsets are NOT
  // provided, then assume the channelwise quantization parameters are
  // implicitly: biasScales[i] = inputScale * filterScales[i] and
  // biasOffsets[i] = 0.
  if (!biasScales.getNode() || !biasOffsets.getNode()) {
    Constant *biasC = dyn_cast<Constant>(bias.getNode());
    Constant *biasScalesC = getParent()->createConstant(
        ElemKind::FloatTy, {numChannels}, "biasScales");
    Constant *biasOffsetsC = getParent()->createConstant(
        ElemKind::Int32ITy, {numChannels}, "biasOffsets");
    auto biasScalesH = biasScalesC->getPayload().getHandle<float>();
    auto biasOffsetsH = biasOffsetsC->getPayload().getHandle<int32_t>();
    Constant *filterScalesC = dyn_cast<Constant>(filterScales.getNode());
    Constant *filterOffsetsC = dyn_cast<Constant>(filterOffsets.getNode());
    auto filterScalesH = filterScalesC->getPayload().getHandle<float>();
    auto filterOffsetsH = filterOffsetsC->getPayload().getHandle<int32_t>();
    auto inputScale = input.getType()->getScale();
    auto inputOffset = input.getType()->getOffset();
    if (biasElemKind == ElemKind::FloatTy) {
      // Get the bias channelwise TensorQuantizationParams.
      quantization::getTensorQuantizationParams(
          biasC->getPayload(), biasScalesC->getPayloadMutable(),
          biasOffsetsC->getPayloadMutable(), schema, biasElemQTy, qDim, qStep);
      // Specialize the bias channelwise TensorQuantizationParams.
      for (dim_t idx = 0; idx < numChannels; idx++) {
        auto biasTQPNew = specializeBiasQuantizationParams(
            {biasScalesH.raw(idx), biasOffsetsH.raw(idx)},
            {inputScale, inputOffset},
            {filterScalesH.raw(idx), filterOffsetsH.raw(idx)}, schema,
            biasElemQTy);
        biasScalesH.raw(idx) = biasTQPNew.scale;
        biasOffsetsH.raw(idx) = biasTQPNew.offset;
      }
    } else {
      // Set the implicit bias channelwise TensorQuantizationParams.
      for (dim_t idx = 0; idx < numChannels; idx++) {
        float filterScale = filterScalesH.raw(idx);
        biasScalesH.raw(idx) = inputScale * filterScale;
        biasOffsetsH.raw(idx) = 0;
      }
    }
    biasScales = NodeValue(biasScalesC);
    biasOffsets = NodeValue(biasOffsetsC);
  }

  // If the input bias is FLOAT, quantize it channelwise to biasElemQTy.
  if (quantizeBias && biasElemKind == ElemKind::FloatTy) {
    Constant *biasC = dyn_cast<Constant>(bias.getNode());
    Constant *biasCQ = getParent()->createConstant(
        biasElemQTy, biasC->getType()->dims(), 1.0, 0, "bias");
    Constant *biasScalesC = dyn_cast<Constant>(biasScales.getNode());
    Constant *biasOffsetsC = dyn_cast<Constant>(biasOffsets.getNode());
    // Quantize the bias channelwise.
    biasCQ->getPayloadMutable() = quantization::quantizeTensor(
        biasC->getPayload(), biasScalesC->getPayload(),
        biasOffsetsC->getPayload(), biasElemQTy, qDim, qStep);
    bias = NodeValue(biasCQ);
  }

  auto OT = getParent()->uniqueType(*outTy);
  return addNode(new ChannelwiseQuantizedConvolutionNode(
      name, OT, input, filter, bias, filterScales, filterOffsets, biasScales,
      biasOffsets, kernels, strides, pads, group, dilation));
}
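
// Worked numbers (illustrative): with inputScale = 0.02 and
// filterScales[i] = 0.5, the implicit parameters above give
// biasScales[i] = 0.02 * 0.5 = 0.01 and biasOffsets[i] = 0, so a float bias
// of 0.25 in channel i would quantize to round(0.25 / 0.01) = 25.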

ConvTransposeNode *Function::createConvTranspose(
    PlaceholderBindings &bindings, llvm::StringRef name, NodeValue input,
    dim_t outChannels, llvm::ArrayRef<unsigned_t> kernels,
    llvm::ArrayRef<unsigned_t> strides, llvm::ArrayRef<unsigned_t> pads,
    unsigned_t group, unsigned_t dilation) {
  ShapeNHWC idim = ShapeNHWC(input.dims());
  ShapeHW kdim(kernels);
  PaddingTLBR pdim(pads);
  (void)pdim;
  assert((idim.w + pdim.left + pdim.right) >= kdim.width &&
         (idim.h + pdim.top + pdim.bottom) >= kdim.height &&
         "buffer too small for selected stride");

  assert(group > 0 && "group should be larger than 0");
  assert(idim.c % group == 0 && "channels number must be divisible by groups");
  assert(outChannels % group == 0 && "outChannels must be divisible by groups");

  // Calculate the size and allocate the output buffer.
  auto outSz = calculateConvTransposeOutputDims(idim.h, idim.w, kernels,
                                                strides, pads, dilation);

  std::array<dim_t, 4> outDims = {
      {idim.n, outSz.first, outSz.second, outChannels}};

  // Allocate the Filter and Bias tensors.
  std::array<dim_t, 4> filterDim = {
      {outChannels, kdim.height, kdim.width, idim.c / group}};
  size_t fanIn = kdim.height * kdim.width * idim.c;
  ElemKind inputTy = input.getType()->getElementType();
  assert((inputTy == ElemKind::FloatTy || inputTy == ElemKind::Float16Ty) &&
         "ConvTranspose on non-floating point type?");
  auto *filter =
      getParent()->createPlaceholder(inputTy, filterDim, "filter", true);

  auto *bias =
      getParent()->createPlaceholder(inputTy, {outChannels}, "bias", true);
  bindings.allocate(bias)->init(glow::Tensor::InitKind::Broadcast, 0.1,
                                getPRNG());

  bindings.allocate(filter)->init(glow::Tensor::InitKind::Xavier, fanIn,
                                  getPRNG());

  auto OT = getParent()->uniqueType(inputTy, outDims);

  return addNode(new ConvTransposeNode(name, OT, input, filter, bias, kernels,
                                       strides, pads, group, dilation));
}

ConvTransposeNode *Function::createConvTranspose(
    PlaceholderBindings &bindings, llvm::StringRef name, NodeValue input,
    dim_t outChannels, unsigned_t kernel, unsigned_t stride, unsigned_t pad,
    unsigned_t group, unsigned_t dilation) {
  llvm::SmallVector<unsigned_t, 4> pads = {pad, pad, pad, pad};
  llvm::SmallVector<unsigned_t, 2> strides = {stride, stride};
  llvm::SmallVector<unsigned_t, 2> kernels = {kernel, kernel};
  return createConvTranspose(bindings, name, input, outChannels, kernels,
                             strides, pads, group, dilation);
}

ConvertToNode *Function::createConvertTo(llvm::StringRef name, NodeValue input,
                                         TypeRef outTy) {
  return addNode(new ConvertToNode(name, outTy, input));
}

ConvertToNode *Function::createConvertTo(llvm::StringRef name, NodeValue input,
                                         ElemKind k) {
  auto OT = getParent()->uniqueType(k, input.dims());
  return addNode(new ConvertToNode(name, OT, input));
}

FullyConnectedNode *
Function::createFullyConnected(PlaceholderBindings &bindings,
                               llvm::StringRef name, NodeValue input,
                               dim_t outDepth, unsigned_t axis) {
  const ElemKind k = input.getType()->getElementType();

  // FC always uses 2D input; flatten if necessary.
  if (input.dims().size() != 2) {
    input = createFlatten(name.str() + ".reshape2D", input, axis);
  }
  auto *W = getParent()->createPlaceholder(k, {input.dims()[1], outDepth},
                                           "weights", true);
  auto *B = getParent()->createPlaceholder(k, {outDepth}, "bias", true);

  bindings.allocate(W)->init(Tensor::InitKind::Xavier, input.dims()[1],
                             getPRNG());
  bindings.allocate(B)->init(Tensor::InitKind::Broadcast, .1, getPRNG());

  auto OT = getParent()->uniqueType(k, {input.dims()[0], outDepth});
  return createFullyConnected(name, input, W, B, OT, axis);
}

Node *Function::createDotProduct(llvm::StringRef name, NodeValue X,
                                 NodeValue Y) {
  auto XDimsSize = X.dims().size();
  (void)XDimsSize;

  assert(X.dims() == Y.dims() && "X and Y must have the same shape");
  assert(((XDimsSize == 1) || (XDimsSize == 2)) && "X and Y must be 1D or 2D");

  // Create the Mul node.
  auto *MN = createMul(name.str() + ".mul", X, Y);

  // If X and Y are 1D, the BatchedReduceAdd node is not needed.
  if (XDimsSize == 1) {
    return MN;
  }

  // Create and return the BatchedReduceAdd node.
  return createBatchedReduceAdd(name.str() + ".bra", MN, 1);
}
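
// Usage sketch (hypothetical names): for 2-D X and Y of shape {B, D}, the
// elementwise product is reduced over axis 1, producing one dot product per
// row:
//
//   Node *dp = F->createDotProduct("dp", X, Y); // shape {B}
//
// For 1-D inputs the elementwise Mul is returned without a reduction.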

BatchedPairwiseDotProductNode *
Function::createBatchedPairwiseDotProduct(llvm::StringRef name,
                                          llvm::ArrayRef<NodeValue> inputs) {
  assert(!inputs.empty());
  dim_t batchCount = inputs[0].getType()->dims()[0];
  dim_t numPairs = inputs.size() * (inputs.size() - 1) / 2;
  auto *outTy = getParent()->uniqueTypeWithNewShape(inputs[0].getType(),
                                                    {batchCount, numPairs});

  return addNode(new BatchedPairwiseDotProductNode(name, outTy, inputs));
}

Node *Function::createElementwiseLinear(llvm::StringRef name, NodeValue X,
                                        NodeValue w, NodeValue b,
                                        unsigned axis) {
  auto XDims = X.dims();
  auto wDims = w.dims();
  auto bDims = b.dims();

  // Suppress release mode unused variable warnings.
  (void)wDims;
  (void)bDims;

  // Check that the inputs are sensible.
  assert(XDims.size() == 2 && "X must be 2D");
  assert((axis == 0 || axis == 1) && "axis must be 0 or 1");
  assert(wDims.size() == 1 && "w must be 1D");
  assert(bDims.size() == 1 && "b must be 1D");
  assert(wDims[0] == XDims[axis] &&
         "size of w must match input dimension of X");
  assert(bDims[0] == XDims[axis] &&
         "size of b must match input dimension of X");

  // Broadcast w and b so that they have the same dimensions as X.
  auto *broadcastW =
      createBroadcast(name.str() + ".broadcastW", w, XDims, axis);
  auto *broadcastB =
      createBroadcast(name.str() + ".broadcastB", b, XDims, axis);

  // Implement the elementwise linear operation by multiplying X elementwise
  // with the broadcasted w and adding the broadcasted b elementwise.
  auto *wX = createMul(name.str() + ".mul", broadcastW, X);
  auto *out = createAdd(name.str() + ".add", wX, broadcastB);

  return out;
}

void Function::createGRU(PlaceholderBindings &bindings,
                         llvm::StringRef namePrefix,
                         llvm::ArrayRef<NodeValue> inputs, unsigned batchSize,
                         unsigned hiddenSize, unsigned outputSize,
                         std::vector<NodeValue> &outputs) {
  std::string nameBase = namePrefix;
  const unsigned timeSteps = inputs.size();
  assert(timeSteps > 0 && "empty input");
  const unsigned inputSize = inputs.front().dims().back();
  assert(inputSize > 0 && "input dimensionality is zero");

  // Initialize the state to zero.
  Placeholder *HInit = getParent()->createPlaceholder(
      ElemKind::FloatTy, {batchSize, hiddenSize}, "initial_state", false);
  bindings.allocate(HInit)->zero();
  Node *Ht = HInit;

  // Update gate:
  //   Z <- sigmoid(Wxz * x + Whz * h + bz)
  // Reset gate:
  //   R <- sigmoid(Wxr * x + Whr * h + br)
  // Hidden state:
  //   h <- Z . h + (1 - Z) . tanh(Wxh * x + Whh * (R . h) + bh)

  // Update gate.
  float bUpdate = 0.1;
  Placeholder *Wxz = getParent()->createPlaceholder(
      ElemKind::FloatTy, {inputSize, hiddenSize}, nameBase + ".Wxz", true);
  Placeholder *Whz = getParent()->createPlaceholder(
      ElemKind::FloatTy, {hiddenSize, hiddenSize}, nameBase + ".Whz", true);
  Placeholder *Bz1 = getParent()->createPlaceholder(
      ElemKind::FloatTy, {hiddenSize}, nameBase + ".bz1", true);
  Placeholder *Bz2 = getParent()->createPlaceholder(
      ElemKind::FloatTy, {hiddenSize}, nameBase + ".bz2", true);

  bindings.allocate(Wxz)->init(glow::Tensor::InitKind::Xavier, inputSize,
                               getPRNG());
  bindings.allocate(Whz)->init(glow::Tensor::InitKind::Xavier, hiddenSize,
                               getPRNG());
  bindings.allocate(Bz1)->init(glow::Tensor::InitKind::Broadcast, bUpdate,
                               getPRNG());
  bindings.allocate(Bz2)->init(glow::Tensor::InitKind::Broadcast, bUpdate,
                               getPRNG());

  // Reset gate.
  float bReset = -1.0;
  Placeholder *Wxr = getParent()->createPlaceholder(
      ElemKind::FloatTy, {inputSize, hiddenSize}, nameBase + ".Wxr", true);
  Placeholder *Whr = getParent()->createPlaceholder(
      ElemKind::FloatTy, {hiddenSize, hiddenSize}, nameBase + ".Whr", true);
  Placeholder *Br1 = getParent()->createPlaceholder(
      ElemKind::FloatTy, {hiddenSize}, nameBase + ".br1", true);
  Placeholder *Br2 = getParent()->createPlaceholder(
      ElemKind::FloatTy, {hiddenSize}, nameBase + ".br2", true);

  bindings.allocate(Wxr)->init(glow::Tensor::InitKind::Xavier, inputSize,
                               getPRNG());
  bindings.allocate(Whr)->init(glow::Tensor::InitKind::Xavier, hiddenSize,
                               getPRNG());
  bindings.allocate(Br1)->init(glow::Tensor::InitKind::Broadcast, bReset,
                               getPRNG());
  bindings.allocate(Br2)->init(glow::Tensor::InitKind::Broadcast, bReset,
                               getPRNG());

  // Hidden state.
  float b = 0.1;
  Placeholder *Wxh = getParent()->createPlaceholder(
      ElemKind::FloatTy, {inputSize, hiddenSize}, nameBase + ".Wxh", true);
  Placeholder *Whh = getParent()->createPlaceholder(
      ElemKind::FloatTy, {hiddenSize, hiddenSize}, nameBase + ".Whh", true);
  Placeholder *Bh1 = getParent()->createPlaceholder(
      ElemKind::FloatTy, {hiddenSize}, nameBase + ".bh1", true);
  Placeholder *Bh2 = getParent()->createPlaceholder(
      ElemKind::FloatTy, {hiddenSize}, nameBase + ".bh2", true);

  bindings.allocate(Wxh)->init(glow::Tensor::InitKind::Xavier, inputSize,
                               getPRNG());
  bindings.allocate(Whh)->init(glow::Tensor::InitKind::Xavier, hiddenSize,
                               getPRNG());
  bindings.allocate(Bh1)->init(glow::Tensor::InitKind::Broadcast, b, getPRNG());
  bindings.allocate(Bh2)->init(glow::Tensor::InitKind::Broadcast, b, getPRNG());

  // Output layer.
  Placeholder *Why = getParent()->createPlaceholder(
      ElemKind::FloatTy, {hiddenSize, outputSize}, nameBase + ".Why", true);
  Placeholder *By = getParent()->createPlaceholder(
      ElemKind::FloatTy, {outputSize}, nameBase + ".by", true);

  bindings.allocate(Why)->init(glow::Tensor::InitKind::Xavier, hiddenSize,
                               getPRNG());
  bindings.allocate(By)->init(glow::Tensor::InitKind::Broadcast, b, getPRNG());

  auto ty = getParent()->uniqueType(ElemKind::FloatTy, {batchSize, hiddenSize});
  auto *Ones = createSplat(nameBase + ".ones", ty, 1.0);

  for (unsigned t = 0; t < timeSteps; t++) {
    auto fc1Name = nameBase + ".fc1." + std::to_string(t);
    auto fc2Name = nameBase + ".fc2." + std::to_string(t);
    auto add1Name = nameBase + ".add1." + std::to_string(t);
    auto sigmoid1Name = nameBase + ".sigmoid1." + std::to_string(t);

    auto *Zt = createSigmoid(
        sigmoid1Name,
        createAdd(add1Name, createFullyConnected(fc1Name, Ht, Whz, Bz1),
                  createFullyConnected(fc2Name, inputs[t], Wxz, Bz2)));

    auto fc3Name = nameBase + ".fc3." + std::to_string(t);
    auto fc4Name = nameBase + ".fc4." + std::to_string(t);
    auto add2Name = nameBase + ".add2." + std::to_string(t);
    auto sigmoid2Name = nameBase + ".sigmoid2." + std::to_string(t);

    auto *Rt = createSigmoid(
        sigmoid2Name,
        createAdd(add2Name, createFullyConnected(fc3Name, Ht, Whr, Br1),
                  createFullyConnected(fc4Name, inputs[t], Wxr, Br2)));

    auto zhtName = nameBase + ".zh." + std::to_string(t);
    auto *ZHt = createMul(zhtName, Zt, Ht);

    auto oneMinusZtName = nameBase + ".1-z." + std::to_string(t);
    auto *OneMinusZt = createSub(oneMinusZtName, Ones, Zt);

    auto rhtName = nameBase + ".rh." + std::to_string(t);
    auto *RHt = createMul(rhtName, Rt, Ht);

    auto fc5Name = nameBase + ".fc5." + std::to_string(t);
    auto fc6Name = nameBase + ".fc6." + std::to_string(t);
    auto add3Name = nameBase + ".add3." + std::to_string(t);
    auto tanh1Name = nameBase + ".tanh1." + std::to_string(t);

    auto *Ut = createTanh(
        tanh1Name,
        createAdd(add3Name, createFullyConnected(fc5Name, RHt, Whh, Bh1),
                  createFullyConnected(fc6Name, inputs[t], Wxh, Bh2)));

    auto oneMinusZtUtName = nameBase + ".1-zu." + std::to_string(t);
    auto *OneMinusZtUt = createMul(oneMinusZtUtName, OneMinusZt, Ut);

    auto htName = nameBase + ".H." + std::to_string(t);
    Ht = createAdd(htName, ZHt, OneMinusZtUt);

    auto outName = nameBase + ".out." + std::to_string(t);
    auto *O = createFullyConnected(outName, Ht, Why, By);
    outputs.push_back(O);
  }
}
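
// Usage sketch (hypothetical names): the GRU is unrolled over inputs.size()
// time steps, and one {batchSize, outputSize} projection is appended to
// `outputs` per step:
//
//   std::vector<NodeValue> outs;
//   F->createGRU(bindings, "gru", inputs, /* batchSize */ 16,
//                /* hiddenSize */ 64, /* outputSize */ 10, outs);
//   // outs[t] is the output of the shared FC layer at time step t.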

void Function::createSimpleRNN(PlaceholderBindings &bindings,
                               llvm::StringRef namePrefix,
                               llvm::ArrayRef<NodeValue> inputs,
                               unsigned batchSize, unsigned hiddenSize,
                               unsigned outputSize,
                               std::vector<NodeValue> &outputs) {
  std::string nameBase = namePrefix;
  const unsigned timeSteps = inputs.size();
  assert(timeSteps > 0 && "empty input");
  const unsigned inputSize = inputs.front().dims().back();
  assert(inputSize > 0 && "input dimensionality is zero");

  // Initialize the state to zero.
  Placeholder *HInit =
      getParent()->createPlaceholder(ElemKind::FloatTy, {batchSize, hiddenSize},
                                     nameBase + ".initial_state", false);
  bindings.allocate(HInit)->zero();
  Node *Ht = HInit;

  float b = 0.1;
  Placeholder *Whh = getParent()->createPlaceholder(
      ElemKind::FloatTy, {hiddenSize, hiddenSize}, nameBase + ".Whh", true);
  Placeholder *Bhh = getParent()->createPlaceholder(
      ElemKind::FloatTy, {hiddenSize}, nameBase + ".Bhh", true);
  Placeholder *Wxh = getParent()->createPlaceholder(
      ElemKind::FloatTy, {inputSize, hiddenSize}, nameBase + ".Wxh", true);

  Placeholder *Bxh = getParent()->createPlaceholder(
      ElemKind::FloatTy, {hiddenSize}, nameBase + ".Bxh", true);
  Placeholder *Why = getParent()->createPlaceholder(
      ElemKind::FloatTy, {hiddenSize, outputSize}, nameBase + ".Why", true);
  Placeholder *Bhy = getParent()->createPlaceholder(
      ElemKind::FloatTy, {outputSize}, nameBase + ".Bhy", true);

  bindings.allocate(Whh)->init(glow::Tensor::InitKind::Xavier, hiddenSize,
                               getPRNG());
  bindings.allocate(Bhh)->init(glow::Tensor::InitKind::Broadcast, b, getPRNG());
  bindings.allocate(Wxh)->init(glow::Tensor::InitKind::Xavier, inputSize,
                               getPRNG());
  bindings.allocate(Bxh)->init(glow::Tensor::InitKind::Broadcast, b, getPRNG());
  bindings.allocate(Why)->init(glow::Tensor::InitKind::Xavier, hiddenSize,
                               getPRNG());
  bindings.allocate(Bhy)->init(glow::Tensor::InitKind::Broadcast, b, getPRNG());

  // Un-roll backpropagation through time as a loop with the shared
  // parameters.
  for (unsigned t = 0; t < timeSteps; t++) {
    auto fc1Name = nameBase + ".fc1." + std::to_string(t);
    auto *FC1 = createFullyConnected(fc1Name, Ht, Whh, Bhh);
    auto fc2Name = nameBase + ".fc2." + std::to_string(t);
    auto *FC2 = createFullyConnected(fc2Name, inputs[t], Wxh, Bxh);
    auto aName = nameBase + ".add." + std::to_string(t);
    auto *A = createAdd(aName, FC1, FC2);
    auto tanhName = nameBase + ".tanh." + std::to_string(t);
    auto *H = createTanh(tanhName, A);
    auto outName = nameBase + ".out." + std::to_string(t);
    auto *O = createFullyConnected(outName, H, Why, Bhy);
    outputs.push_back(O);

    Ht = H;
  }
}

void Function::createLSTM(PlaceholderBindings &bindings,
                          llvm::StringRef namePrefix,
                          llvm::ArrayRef<NodeValue> inputs, unsigned batchSize,
                          unsigned hiddenSize, unsigned outputSize,
                          std::vector<NodeValue> &outputs) {
  std::string nameBase = namePrefix;
  const unsigned timeSteps = inputs.size();
  assert(timeSteps > 0 && "empty input");
  const unsigned inputSize = inputs.front().dims().back();
  assert(inputSize > 0 && "input dimensionality is zero");

  // Initialize the hidden and cell states to zero.
  Placeholder *HInit =
      getParent()->createPlaceholder(ElemKind::FloatTy, {batchSize, hiddenSize},
                                     "initial_hidden_state", false);
  bindings.allocate(HInit)->zero();
  Node *Ht = HInit;

  Placeholder *CInit = getParent()->createPlaceholder(
      ElemKind::FloatTy, {batchSize, hiddenSize}, "initial_cell_state", false);
  bindings.allocate(CInit)->zero();
  Node *Ct = CInit;

  // Forget gate:
  //   F <- sigmoid(Wxf * x + Whf * h + bf)
  // Input gate:
  //   I <- sigmoid(Wxi * x + Whi * h + bi)
  // Output gate:
  //   O <- sigmoid(Wxo * x + Who * h + bo)
  // Cell state:
  //   C <- F . C + I . tanh(Wxc * x + Whc * h + bc)
  // Hidden state:
  //   h <- O . tanh(C)

  // Forget gate.
  float bForget = 1.0;
  Placeholder *Wxf = getParent()->createPlaceholder(
      ElemKind::FloatTy, {inputSize, hiddenSize}, nameBase + ".Wxf", true);
  Placeholder *Whf = getParent()->createPlaceholder(
      ElemKind::FloatTy, {hiddenSize, hiddenSize}, nameBase + ".Whf", true);
  Placeholder *Bf1 = getParent()->createPlaceholder(
      ElemKind::FloatTy, {hiddenSize}, nameBase + ".bf1", true);
  Placeholder *Bf2 = getParent()->createPlaceholder(
      ElemKind::FloatTy, {hiddenSize}, nameBase + ".bf2", true);
  bindings.allocate(Wxf)->init(glow::Tensor::InitKind::Xavier, inputSize,
                               getPRNG());
  bindings.allocate(Whf)->init(glow::Tensor::InitKind::Xavier, hiddenSize,
                               getPRNG());
  bindings.allocate(Bf1)->init(glow::Tensor::InitKind::Broadcast, bForget,
                               getPRNG());
  bindings.allocate(Bf2)->init(glow::Tensor::InitKind::Broadcast, bForget,
                               getPRNG());

  // Input gate.
  float bInput = 0.1;
  Placeholder *Wxi = getParent()->createPlaceholder(
      ElemKind::FloatTy, {inputSize, hiddenSize}, nameBase + ".Wxi", true);
  Placeholder *Whi = getParent()->createPlaceholder(
      ElemKind::FloatTy, {hiddenSize, hiddenSize}, nameBase + ".Whi", true);
  Placeholder *Bi1 = getParent()->createPlaceholder(
      ElemKind::FloatTy, {hiddenSize}, nameBase + ".bi1", true);
  Placeholder *Bi2 = getParent()->createPlaceholder(
      ElemKind::FloatTy, {hiddenSize}, nameBase + ".bi2", true);

  bindings.allocate(Wxi)->init(glow::Tensor::InitKind::Xavier, inputSize,
                               getPRNG());
  bindings.allocate(Whi)->init(glow::Tensor::InitKind::Xavier, hiddenSize,
                               getPRNG());
  bindings.allocate(Bi1)->init(glow::Tensor::InitKind::Broadcast, bInput,
                               getPRNG());
  bindings.allocate(Bi2)->init(glow::Tensor::InitKind::Broadcast, bInput,
                               getPRNG());

  // Output gate.
  float bOutput = 0.1;
  Placeholder *Wxo = getParent()->createPlaceholder(
      ElemKind::FloatTy, {inputSize, hiddenSize}, nameBase + ".Wxo", true);
  Placeholder *Who = getParent()->createPlaceholder(
      ElemKind::FloatTy, {hiddenSize, hiddenSize}, nameBase + ".Who", true);
  Placeholder *Bo1 = getParent()->createPlaceholder(
      ElemKind::FloatTy, {hiddenSize}, nameBase + ".bo1", true);
  Placeholder *Bo2 = getParent()->createPlaceholder(
      ElemKind::FloatTy, {hiddenSize}, nameBase + ".bo2", true);

  bindings.allocate(Wxo)->init(glow::Tensor::InitKind::Xavier, inputSize,
                               getPRNG());
  bindings.allocate(Who)->init(glow::Tensor::InitKind::Xavier, hiddenSize,
                               getPRNG());
  bindings.allocate(Bo1)->init(glow::Tensor::InitKind::Broadcast, bOutput,
                               getPRNG());
  bindings.allocate(Bo2)->init(glow::Tensor::InitKind::Broadcast, bOutput,
                               getPRNG());

  // Cell state.
  float bCell = 0.1;
  Placeholder *Wxc = getParent()->createPlaceholder(
      ElemKind::FloatTy, {inputSize, hiddenSize}, nameBase + ".Wxc", true);
  Placeholder *Whc = getParent()->createPlaceholder(
      ElemKind::FloatTy, {hiddenSize, hiddenSize}, nameBase + ".Whc", true);
  Placeholder *Bc1 = getParent()->createPlaceholder(
      ElemKind::FloatTy, {hiddenSize}, nameBase + ".bc1", true);
  Placeholder *Bc2 = getParent()->createPlaceholder(
      ElemKind::FloatTy, {hiddenSize}, nameBase + ".bc2", true);

  bindings.allocate(Wxc)->init(glow::Tensor::InitKind::Xavier, inputSize,
                               getPRNG());
  bindings.allocate(Whc)->init(glow::Tensor::InitKind::Xavier, hiddenSize,
                               getPRNG());
  bindings.allocate(Bc1)->init(glow::Tensor::InitKind::Broadcast, bCell,
                               getPRNG());
  bindings.allocate(Bc2)->init(glow::Tensor::InitKind::Broadcast, bCell,
                               getPRNG());

  // Output layer.
  float b = 0.1;
  Placeholder *Why = getParent()->createPlaceholder(
      ElemKind::FloatTy, {hiddenSize, outputSize}, nameBase + ".Why", true);
  Placeholder *By = getParent()->createPlaceholder(
      ElemKind::FloatTy, {outputSize}, nameBase + ".by", true);

  bindings.allocate(Why)->init(glow::Tensor::InitKind::Xavier, hiddenSize,
                               getPRNG());
  bindings.allocate(By)->init(glow::Tensor::InitKind::Broadcast, b, getPRNG());

  for (unsigned t = 0; t < timeSteps; t++) {
    auto fc1Name = nameBase + ".fc1." + std::to_string(t);
    auto fc2Name = nameBase + ".fc2." + std::to_string(t);
    auto add1Name = nameBase + ".add1." + std::to_string(t);
    auto sigmoid1Name = nameBase + ".sigmoid1." + std::to_string(t);

    auto *Ft = createSigmoid(
        sigmoid1Name,
        createAdd(add1Name, createFullyConnected(fc1Name, Ht, Whf, Bf1),
                  createFullyConnected(fc2Name, inputs[t], Wxf, Bf2)));

    auto fc3Name = nameBase + ".fc3." + std::to_string(t);
    auto fc4Name = nameBase + ".fc4." + std::to_string(t);
    auto add2Name = nameBase + ".add2." + std::to_string(t);
    auto sigmoid2Name = nameBase + ".sigmoid2." + std::to_string(t);

    auto *It = createSigmoid(
        sigmoid2Name,
        createAdd(add2Name, createFullyConnected(fc3Name, Ht, Whi, Bi1),
                  createFullyConnected(fc4Name, inputs[t], Wxi, Bi2)));

    auto fc5Name = nameBase + ".fc5." + std::to_string(t);
    auto fc6Name = nameBase + ".fc6." + std::to_string(t);
    auto add3Name = nameBase + ".add3." + std::to_string(t);
    auto sigmoid3Name = nameBase + ".sigmoid3." + std::to_string(t);

    auto *Ot = createSigmoid(
        sigmoid3Name,
        createAdd(add3Name, createFullyConnected(fc5Name, Ht, Who, Bo1),
                  createFullyConnected(fc6Name, inputs[t], Wxo, Bo2)));

    auto fc7Name = nameBase + ".fc7." + std::to_string(t);
    auto fc8Name = nameBase + ".fc8." + std::to_string(t);
    auto add4Name = nameBase + ".add4." + std::to_string(t);
    auto tanh1Name = nameBase + ".tanh1." + std::to_string(t);

    auto *CRt = createTanh(
        tanh1Name,
        createAdd(add4Name, createFullyConnected(fc7Name, Ht, Whc, Bc1),
                  createFullyConnected(fc8Name, inputs[t], Wxc, Bc2)));

    auto mul1Name = nameBase + ".mul1." + std::to_string(t);
    auto mul2Name = nameBase + ".mul2." + std::to_string(t);
    Ct = createAdd(nameBase + ".C." + std::to_string(t),
                   createMul(mul1Name, Ft, Ct), createMul(mul2Name, It, CRt));

    auto htName = nameBase + ".H." + std::to_string(t);
    auto tanh2Name = nameBase + ".tanh2." + std::to_string(t);
    Ht = createMul(htName, Ot, createTanh(tanh2Name, Ct));

    auto outName = nameBase + ".out." + std::to_string(t);
    auto *O = createFullyConnected(outName, Ht, Why, By);
    outputs.push_back(O);
  }
}

void Function::createOnnxRNN(llvm::StringRef namePrefix, NodeValue X,
                             NodeValue W, NodeValue R, NodeValue B,
                             NodeValue initial_h, NodeValue &Y, NodeValue &Y_h,
                             unsigned hiddenSize, RnnDirection direction,
                             std::vector<RnnActivation> &activations) {

#define RNN_X_SLICE_RANGE(idx)                                                 \
  {idx + 0, 0, 0}, { idx + 1, batchSize, inputSize }
#define RNN_W_SLICE_RANGE(idx0, idx1)                                          \
  {idx0, idx1 * hiddenSize, 0}, { idx0 + 1, (idx1 + 1) * hiddenSize, inputSize }
#define RNN_R_SLICE_RANGE(idx0, idx1)                                          \
  {idx0, idx1 * hiddenSize, 0}, {                                              \
    idx0 + 1, (idx1 + 1) * hiddenSize, hiddenSize                              \
  }
#define RNN_B_SLICE_RANGE(idx0, idx1)                                          \
  {idx0, idx1 * hiddenSize}, { idx0 + 1, (idx1 + 1) * hiddenSize }
#define RNN_H_SLICE_RANGE(idx)                                                 \
  {idx + 0, 0, 0}, { idx + 1, batchSize, hiddenSize }
#define RNN_CREATE_FC(name, LHS, RHS, BIAS)                                    \
  BIAS ? (Node *)createFullyConnected(name, LHS, RHS, BIAS)                    \
       : (Node *)createMatMul(name, LHS, RHS)
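
// For example, RNN_X_SLICE_RANGE(t) expands to the two brace lists
// {t + 0, 0, 0}, {t + 1, batchSize, inputSize}, which createSlice below
// consumes as the start and end coordinates of the slice for time step t.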
3624
3625 // Operator name.
3626 const std::string &opName = namePrefix.str();
3627
3628 // Get all size parameters.
3629 dim_t numDirections = (direction == RnnDirection::Bidirectional) ? 2 : 1;
3630 assert(X.dims().size() == 3 &&
3631 "ONNX RNN input 'X' should have 3 dimensions!");
3632 dim_t seqLength = X.dims()[0];
3633 dim_t batchSize = X.dims()[1];
3634 dim_t inputSize = X.dims()[2];
3635
3636 // Validate W size.
3637 assert(W.dims().size() == 3 &&
3638 "ONNX RNN input 'W' should have 3 dimensions!");
3639 assert(W.dims()[0] == numDirections && W.dims()[1] == hiddenSize &&
3640 W.dims()[2] == inputSize && "ONNX RNN 'W' tensor size invalid!");
3641
3642 // Validate R size.
3643 assert(R.dims().size() == 3 &&
3644 "ONNX RNN input 'R' should have 3 dimensions!");
3645 assert(R.dims()[0] == numDirections && R.dims()[1] == hiddenSize &&
3646 R.dims()[2] == hiddenSize && "ONNX RNN 'R' tensor size invalid!");
3647
3648 // Validate B size.
3649 if (B.getNode()) {
3650 assert(B.dims().size() == 2 &&
3651 "ONNX RNN input 'B' should have 2 dimensions!");
3652 assert(B.dims()[0] == numDirections && B.dims()[1] == 2 * hiddenSize &&
3653 "ONNX RNN 'B' tensor size invalid!");
3654 }
3655
3656 // Validate initial_h size.
3657 assert(initial_h.getNode() &&
3658 "ONNX RNN input 'initial_h' is mandatory. Null provided!");
3659 assert(initial_h.dims().size() == 3 &&
3660 "ONNX RNN input 'initial_h' should have 2 dimensions!");
3661 assert(initial_h.dims()[0] == numDirections &&
3662 initial_h.dims()[1] == batchSize &&
3663 initial_h.dims()[2] == hiddenSize &&
3664 "ONNX RNN 'initial_h' tensor size invalid!");
3665
3666 // Validate number of activations.
3667 assert(activations.size() == numDirections * 1 &&
3668 "ONNX RNN activations vector invalid!");
3669
3670 // Create X slices.
3671 std::vector<Node *> Xslices;
3672 for (dim_t t = 0; t < seqLength; t++) {
3673 auto XsliceName = opName + ".X" + std::to_string(t) + ".slice";
3674 Node *Xt = createSlice(XsliceName, X, RNN_X_SLICE_RANGE(t));
3675 auto XreshapeName = opName + ".X" + std::to_string(t) + ".reshape";
3676 Xt = createReshape(XreshapeName, Xt, {batchSize, inputSize});
3677 Xslices.push_back(Xt);
3678 }
3679
3680 // Lambda to load forward/backward RNN cell.
3681 auto loadRNNCell = [&](bool forward, std::vector<NodeValue> &Yslices,
3682 NodeValue &Hslice) {
3683 // Name prefix.
3684 std::string dirLabel = forward ? ".fw" : ".bw";
3685 std::string prefix = opName + ((numDirections > 1) ? dirLabel : "");
3686
3687 // Slice index used for creating weights slices.
3688 dim_t sliceIdx0 = 0;
3689 if (direction == RnnDirection::Bidirectional) {
3690 sliceIdx0 = forward ? 0 : 1;
3691 }
3692
3693 // Activations.
3694 size_t activationOffset = sliceIdx0 * 1;
3695 auto activationF = activations[activationOffset + 0];
3696
3697 // Create W slice (Required).
3698 NodeValue Wi =
3699 createSlice(prefix + ".Wi.", W, RNN_W_SLICE_RANGE(sliceIdx0, 0));
3700 Wi = createReshape(prefix + ".Wi.reshape", Wi, {hiddenSize, inputSize});
3701 Wi = createTranspose(prefix + ".Wi.transp", Wi, {1, 0});
3702
3703 // Create R slice (Required).
3704 NodeValue Ri =
3705 createSlice(prefix + ".Ri.", R, RNN_R_SLICE_RANGE(sliceIdx0, 0));
3706 Ri = createReshape(prefix + ".Ri.reshape", Ri, {hiddenSize, hiddenSize});
3707 Ri = createTranspose(prefix + ".Ri.transp", Ri, {1, 0});
3708
3709 // Create B slices (optional).
3710 NodeValue bWi = nullptr;
3711 NodeValue bRi = nullptr;
3712
3713 if (B) {
3714
3715 bWi = createSlice(prefix + ".bWi.", B, RNN_B_SLICE_RANGE(sliceIdx0, 0));
3716 bRi = createSlice(prefix + ".bRi.", B, RNN_B_SLICE_RANGE(sliceIdx0, 1));
3717
3718 bWi = createReshape(prefix + ".bWi.reshape", bWi, {hiddenSize});
3719 bRi = createReshape(prefix + ".bRi.reshape", bRi, {hiddenSize});
3720 }
3721
3722 // Create H slice for this direction.
3723 Node *Hinit = createSlice(prefix + ".H.slice", initial_h,
3724 RNN_H_SLICE_RANGE(sliceIdx0));
3725 Hinit =
3726 createReshape(prefix + ".H.reshape", Hinit, {batchSize, hiddenSize});
3727
3728 // Initialize.
3729 Node *Ht = Hinit;
3730
3731 // Unroll RNN cell for all time steps.
3732 for (size_t t = 0; t < seqLength; t++) {
3733
3734 // Input for current time step.
3735 // For the reverse RNN cell the inputs are provided in reverse order.
3736 Node *Xt = forward ? Xslices[t] : Xslices[seqLength - 1 - t];
3737
3738 // Hidden state update: Ht = f(Xt * Wi + bWi + Ht-1 * Ri + bRi).
3739 Ht = createAdd(prefix + ".H.add",
3740 RNN_CREATE_FC(prefix + ".H.fc1", Xt, Wi, bWi),
3741 RNN_CREATE_FC(prefix + ".H.fc2", Ht, Ri, bRi));
3742 Ht = activationF(prefix + ".H.act", Ht);
3743
3744 // Output.
3745 Yslices.push_back(Ht);
3746 }
3747
3748 // Updated state nodes.
3749 Hslice = Ht;
3750 }; // End of local lambda "loadRNNCell".
3751
3752 bool forwardEnabled = ((direction == RnnDirection::Forward) ||
3753 (direction == RnnDirection::Bidirectional));
3754 bool backwardEnabled = ((direction == RnnDirection::Reverse) ||
3755 (direction == RnnDirection::Bidirectional));
3756
3757 std::vector<NodeValue> YSlices;
3758 std::vector<NodeValue> Hslices;
3759
3760 // Load forward RNN.
3761 std::vector<NodeValue> forwardYslices;
3762 if (forwardEnabled) {
3763 NodeValue forwardHslice;
3764 loadRNNCell(/* forward */ true, forwardYslices, forwardHslice);
3765 Hslices.push_back(forwardHslice);
3766 }
3767
3768 // Load backward RNN.
3769 std::vector<NodeValue> backwardYslices;
3770 if (backwardEnabled) {
3771 NodeValue backwardHslice;
3772 loadRNNCell(/* forward */ false, backwardYslices, backwardHslice);
3773 Hslices.push_back(backwardHslice);
3774 }
3775
3776 // Gather Y slices.
3777 for (size_t t = 0; t < seqLength; t++) {
3778 if (forwardEnabled) {
3779 YSlices.push_back(forwardYslices[t]);
3780 }
3781 if (backwardEnabled) {
3782 YSlices.push_back(backwardYslices[seqLength - 1 - t]);
3783 }
3784 }
3785
3786 // Concatenate Y slices.
3787 // Y size is [seqLength, numDirections, batchSize, hiddenSize].
3788 Y = createReshape(opName + ".Y.reshape",
3789 createConcat(opName + ".Y.concat", YSlices, 0),
3790 {seqLength, numDirections, batchSize, hiddenSize});
3791
3792 // Concatenate Y_h slices.
3793 // Y_h size is [numDirections, batchSize, hiddenSize].
3794 Y_h = createReshape(opName + ".Y_h.reshape",
3795 createConcat(opName + ".Y_h.concat", Hslices, 0),
3796 {numDirections, batchSize, hiddenSize});
3797
3798 #undef RNN_X_SLICE_RANGE
3799 #undef RNN_W_SLICE_RANGE
3800 #undef RNN_R_SLICE_RANGE
3801 #undef RNN_B_SLICE_RANGE
3802 #undef RNN_H_SLICE_RANGE
3803 #undef RNN_CREATE_FC
3804 }
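
// A minimal usage sketch (illustrative only; the variable names and the Tanh
// choice are assumptions, not requirements of this file). Given a Function *F
// and NodeValues X, W, R, B, initial_h with the shapes validated above, an
// activation is any callable usable as activation(name, node), exactly as in
// the unrolling loop above:
//
//   std::vector<RnnActivation> acts = {
//       [F](llvm::StringRef name, Node *input) -> Node * {
//         return F->createTanh(name, input);
//       }};
//   NodeValue Y, Y_h;
//   F->createOnnxRNN("rnn", X, W, R, B, initial_h, Y, Y_h, hiddenSize,
//                    RnnDirection::Forward, acts);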
3805
3806 void Function::createOnnxGRU(llvm::StringRef namePrefix, NodeValue X,
3807 NodeValue W, NodeValue R, NodeValue B,
3808 NodeValue initial_h, NodeValue &Y, NodeValue &Y_h,
3809 unsigned hiddenSize, RnnDirection direction,
3810 std::vector<RnnActivation> &activations,
3811 bool linearBeforeReset) {
3812
3813 #define GRU_X_SLICE_RANGE(idx) \
3814 {idx + 0, 0, 0}, { idx + 1, batchSize, inputSize }
3815 #define GRU_W_SLICE_RANGE(idx0, idx1) \
3816 {idx0, idx1 * hiddenSize, 0}, { idx0 + 1, (idx1 + 1) * hiddenSize, inputSize }
3817 #define GRU_R_SLICE_RANGE(idx0, idx1) \
3818 {idx0, idx1 * hiddenSize, 0}, { \
3819 idx0 + 1, (idx1 + 1) * hiddenSize, hiddenSize \
3820 }
3821 #define GRU_B_SLICE_RANGE(idx0, idx1) \
3822 {idx0, idx1 * hiddenSize}, { idx0 + 1, (idx1 + 1) * hiddenSize }
3823 #define GRU_H_SLICE_RANGE(idx) \
3824 {idx + 0, 0, 0}, { idx + 1, batchSize, hiddenSize }
3825 #define GRU_CREATE_FC(name, LHS, RHS, BIAS) \
3826 BIAS ? (Node *)createFullyConnected(name, LHS, RHS, BIAS) \
3827 : (Node *)createMatMul(name, LHS, RHS)
3828
3829 // Operator name.
3830 const std::string &opName = namePrefix.str();
3831
3832 // Get all size parameters.
3833 dim_t numDirections = (direction == RnnDirection::Bidirectional) ? 2 : 1;
3834 assert(X.dims().size() == 3 &&
3835 "ONNX GRU input 'X' should have 3 dimensions!");
3836 dim_t seqLength = X.dims()[0];
3837 dim_t batchSize = X.dims()[1];
3838 dim_t inputSize = X.dims()[2];
3839
3840 // Validate W size.
3841 assert(W.dims().size() == 3 &&
3842 "ONNX GRU input 'W' should have 3 dimensions!");
3843 assert(W.dims()[0] == numDirections && W.dims()[1] == 3 * hiddenSize &&
3844 W.dims()[2] == inputSize && "ONNX GRU 'W' tensor size invalid!");
3845
3846 // Validate R size.
3847 assert(R.dims().size() == 3 &&
3848 "ONNX GRU input 'R' should have 3 dimensions!");
3849 assert(R.dims()[0] == numDirections && R.dims()[1] == 3 * hiddenSize &&
3850 R.dims()[2] == hiddenSize && "ONNX GRU 'R' tensor size invalid!");
3851
3852 // Validate B size.
3853 if (B.getNode()) {
3854 assert(B.dims().size() == 2 &&
3855 "ONNX GRU input 'B' should have 2 dimensions!");
3856 assert(B.dims()[0] == numDirections && B.dims()[1] == 6 * hiddenSize &&
3857 "ONNX GRU 'B' tensor size invalid!");
3858 }
3859
3860 // Validate initial_h size.
3861 assert(initial_h.getNode() &&
3862 "ONNX GRU input 'initial_h' is mandatory. Null provided!");
3863 assert(initial_h.dims().size() == 3 &&
3864 "ONNX GRU input 'initial_h' should have 2 dimensions!");
3865 assert(initial_h.dims()[0] == numDirections &&
3866 initial_h.dims()[1] == batchSize &&
3867 initial_h.dims()[2] == hiddenSize &&
3868 "ONNX GRU 'initial_h' tensor size invalid!");
3869
3870 // Validate number of activations.
3871 assert(activations.size() == numDirections * 2 &&
3872 "ONNX GRU activations vector invalid!");
3873
3874 // Create X slices.
3875 std::vector<Node *> Xslices;
3876 for (dim_t t = 0; t < seqLength; t++) {
3877 auto XsliceName = opName + ".X" + std::to_string(t) + ".slice";
3878 Node *Xt = createSlice(XsliceName, X, GRU_X_SLICE_RANGE(t));
3879 auto XreshapeName = opName + ".X" + std::to_string(t) + ".reshape";
3880 Xt = createReshape(XreshapeName, Xt, {batchSize, inputSize});
3881 Xslices.push_back(Xt);
3882 }
3883
3884 // Lambda to load forward/backward GRU cell.
3885 auto loadGRUCell = [&](bool forward, std::vector<NodeValue> &Yslices,
3886 NodeValue &Hslice) {
3887 // Name prefix.
3888 std::string dirLabel = forward ? ".fw" : ".bw";
3889 std::string prefix = opName + ((numDirections > 1) ? dirLabel : "");
3890
3891 // Slice index used for creating weights slices.
3892 dim_t sliceIdx0 = 0;
3893 if (direction == RnnDirection::Bidirectional) {
3894 sliceIdx0 = forward ? 0 : 1;
3895 }
3896
3897 // Activations.
3898 size_t activationOffset = sliceIdx0 * 2;
3899 auto activationF = activations[activationOffset + 0];
3900 auto activationG = activations[activationOffset + 1];
3901
3902 // Create W slices (Required).
3903 NodeValue Wz =
3904 createSlice(prefix + ".Wz.", W, GRU_W_SLICE_RANGE(sliceIdx0, 0));
3905 NodeValue Wr =
3906 createSlice(prefix + ".Wr.", W, GRU_W_SLICE_RANGE(sliceIdx0, 1));
3907 NodeValue Wh =
3908 createSlice(prefix + ".Wh.", W, GRU_W_SLICE_RANGE(sliceIdx0, 2));
3909
3910 Wz = createReshape(prefix + ".Wz.reshape", Wz, {hiddenSize, inputSize});
3911 Wr = createReshape(prefix + ".Wr.reshape", Wr, {hiddenSize, inputSize});
3912 Wh = createReshape(prefix + ".Wh.reshape", Wh, {hiddenSize, inputSize});
3913
3914 Wz = createTranspose(prefix + ".Wz.transp", Wz, {1, 0});
3915 Wr = createTranspose(prefix + ".Wr.transp", Wr, {1, 0});
3916 Wh = createTranspose(prefix + ".Wh.transp", Wh, {1, 0});
3917
3918 // Create R slices (Required).
3919 NodeValue Rz =
3920 createSlice(prefix + ".Rz.", R, GRU_R_SLICE_RANGE(sliceIdx0, 0));
3921 NodeValue Rr =
3922 createSlice(prefix + ".Rr.", R, GRU_R_SLICE_RANGE(sliceIdx0, 1));
3923 NodeValue Rh =
3924 createSlice(prefix + ".Rh.", R, GRU_R_SLICE_RANGE(sliceIdx0, 2));
3925
3926 Rz = createReshape(prefix + ".Rz.reshape", Rz, {hiddenSize, hiddenSize});
3927 Rr = createReshape(prefix + ".Rr.reshape", Rr, {hiddenSize, hiddenSize});
3928 Rh = createReshape(prefix + ".Rh.reshape", Rh, {hiddenSize, hiddenSize});
3929
3930 Rz = createTranspose(prefix + ".Rz.transp", Rz, {1, 0});
3931 Rr = createTranspose(prefix + ".Rr.transp", Rr, {1, 0});
3932 Rh = createTranspose(prefix + ".Rh.transp", Rh, {1, 0});
3933
3934 // Create B slices (optional).
3935 NodeValue bWz = nullptr;
3936 NodeValue bWr = nullptr;
3937 NodeValue bWh = nullptr;
3938 NodeValue bRz = nullptr;
3939 NodeValue bRr = nullptr;
3940 NodeValue bRh = nullptr;
3941
3942 if (B) {
3943
3944 bWz = createSlice(prefix + ".bWz.", B, GRU_B_SLICE_RANGE(sliceIdx0, 0));
3945 bWr = createSlice(prefix + ".bWr.", B, GRU_B_SLICE_RANGE(sliceIdx0, 1));
3946 bWh = createSlice(prefix + ".bWh.", B, GRU_B_SLICE_RANGE(sliceIdx0, 2));
3947 bRz = createSlice(prefix + ".bRz.", B, GRU_B_SLICE_RANGE(sliceIdx0, 3));
3948 bRr = createSlice(prefix + ".bRr.", B, GRU_B_SLICE_RANGE(sliceIdx0, 4));
3949 bRh = createSlice(prefix + ".bRh.", B, GRU_B_SLICE_RANGE(sliceIdx0, 5));
3950
3951 bWz = createReshape(prefix + ".bWz.reshape", bWz, {hiddenSize});
3952 bWr = createReshape(prefix + ".bWr.reshape", bWr, {hiddenSize});
3953 bWh = createReshape(prefix + ".bWh.reshape", bWh, {hiddenSize});
3954 bRz = createReshape(prefix + ".bRz.reshape", bRz, {hiddenSize});
3955 bRr = createReshape(prefix + ".bRr.reshape", bRr, {hiddenSize});
3956 bRh = createReshape(prefix + ".bRh.reshape", bRh, {hiddenSize});
3957 }
3958
3959 // Create H slice for this direction.
3960 Node *Hinit = createSlice(prefix + ".H.slice", initial_h,
3961 GRU_H_SLICE_RANGE(sliceIdx0));
3962 Hinit =
3963 createReshape(prefix + ".H.reshape", Hinit, {batchSize, hiddenSize});
3964
3965 // Initialize.
3966 Node *Ht = Hinit;
3967
3968 // Unroll GRU cell for all time steps.
3969 for (size_t t = 0; t < seqLength; t++) {
3970
3971 // Input for current time step.
3972 // For the reverse GRU cell the inputs are provided in reverse order.
3973 Node *Xt = forward ? Xslices[t] : Xslices[seqLength - 1 - t];
3974
3975 // Update gate: zt = f(Xt * Wz + bWz + Ht-1 * Rz + bRz).
3976 Node *zt = createAdd(prefix + ".Z.add1",
3977 GRU_CREATE_FC(prefix + ".Z.fc1", Xt, Wz, bWz),
3978 GRU_CREATE_FC(prefix + ".Z.fc2", Ht, Rz, bRz));
3979 zt = activationF(prefix + ".Z.act", zt);
3980
3981 // Reset gate: rt = f(Xt * Wr + bWr + Ht-1 * Rr + bRr).
3982 Node *rt = createAdd(prefix + ".R.add1",
3983 GRU_CREATE_FC(prefix + ".R.fc1", Xt, Wr, bWr),
3984 GRU_CREATE_FC(prefix + ".R.fc2", Ht, Rr, bRr));
3985 rt = activationF(prefix + ".R.act", rt);
3986
3987 // Hidden gate:
3988 // For linearBeforeReset = true:
3989 // htild = g(Xt * Wh + bWh + rt . (Ht-1 * Rh + bRh)).
3990 // For linearBeforeReset = false:
3991 // htild = g(Xt * Wh + bWh + (rt . Ht-1) * Rh + bRh).
3992 Node *htild;
3993 if (linearBeforeReset) {
3994 htild = createAdd(
3995 prefix + ".Htild.add",
3996 GRU_CREATE_FC(prefix + ".Htild.fc1", Xt, Wh, bWh),
3997 createMul(prefix + ".Htild.reset", rt,
3998 GRU_CREATE_FC(prefix + ".Htild.fc2", Ht, Rh, bRh)));
3999 } else {
4000 htild = createAdd(
4001 prefix + ".Htild.add",
4002 GRU_CREATE_FC(prefix + ".Htild.fc1", Xt, Wh, bWh),
4003 GRU_CREATE_FC(prefix + ".Htild.fc2",
4004 createMul(prefix + ".Htild.reset", rt, Ht), Rh, bRh));
4005 }
4006 htild = activationG(prefix + ".Htild.act", htild);
4007
4008 // Hidden state update:
4009 // Ht = (1 - zt) . htild + zt . Ht-1 = htild - zt . htild + zt . Ht-1.
4010 Ht = createAdd(prefix + ".H.add",
4011 createSub(prefix + ".H.sub", htild,
4012 createMul(prefix + ".H.mult1", zt, htild)),
4013 createMul(prefix + ".H.mult2", zt, Ht));
4014
4015 // Output.
4016 Yslices.push_back(Ht);
4017 }
4018
4019 // Updated state nodes.
4020 Hslice = Ht;
4021 }; // End of local lambda "loadGRUCell".
4022
4023 bool forwardEnabled = ((direction == RnnDirection::Forward) ||
4024 (direction == RnnDirection::Bidirectional));
4025 bool backwardEnabled = ((direction == RnnDirection::Reverse) ||
4026 (direction == RnnDirection::Bidirectional));
4027
4028 std::vector<NodeValue> YSlices;
4029 std::vector<NodeValue> Hslices;
4030
4031 // Load forward GRU.
4032 std::vector<NodeValue> forwardYslices;
4033 if (forwardEnabled) {
4034 NodeValue forwardHslice;
4035 loadGRUCell(/* forward */ true, forwardYslices, forwardHslice);
4036 Hslices.push_back(forwardHslice);
4037 }
4038
4039 // Load backward GRU.
4040 std::vector<NodeValue> backwardYslices;
4041 if (backwardEnabled) {
4042 NodeValue backwardHslice;
4043 loadGRUCell(/* forward */ false, backwardYslices, backwardHslice);
4044 Hslices.push_back(backwardHslice);
4045 }
4046
4047 // Gather Y slices.
4048 for (size_t t = 0; t < seqLength; t++) {
4049 if (forwardEnabled) {
4050 YSlices.push_back(forwardYslices[t]);
4051 }
4052 if (backwardEnabled) {
4053 YSlices.push_back(backwardYslices[seqLength - 1 - t]);
4054 }
4055 }
4056
4057 // Concatenate Y slices.
4058 // Y size is [seqLength, numDirections, batchSize, hiddenSize].
4059 Y = createReshape(opName + ".Y.reshape",
4060 createConcat(opName + ".Y.concat", YSlices, 0),
4061 {seqLength, numDirections, batchSize, hiddenSize});
4062
4063 // Concatenate Y_h slices.
4064 // Y_h size is [numDirections, batchSize, hiddenSize].
4065 Y_h = createReshape(opName + ".Y_h.reshape",
4066 createConcat(opName + ".Y_h.concat", Hslices, 0),
4067 {numDirections, batchSize, hiddenSize});
4068
4069 #undef GRU_X_SLICE_RANGE
4070 #undef GRU_W_SLICE_RANGE
4071 #undef GRU_R_SLICE_RANGE
4072 #undef GRU_B_SLICE_RANGE
4073 #undef GRU_H_SLICE_RANGE
4074 #undef GRU_CREATE_FC
4075 }
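
// Note on the activations vector (typical usage; the concrete functions are
// the caller's choice): one {f, g} pair per direction, forward first, where f
// gates zt/rt and g produces the candidate htild. The usual ONNX defaults are
// {Sigmoid, Tanh} for a unidirectional GRU and
// {Sigmoid, Tanh, Sigmoid, Tanh} for a bidirectional one.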
4076
4077 void Function::createOnnxLSTM(llvm::StringRef namePrefix, NodeValue X,
4078 NodeValue W, NodeValue R, NodeValue B,
4079 NodeValue initial_h, NodeValue initial_c,
4080 NodeValue P, NodeValue &Y, NodeValue &Y_h,
4081 NodeValue &Y_c, unsigned hiddenSize,
4082 RnnDirection direction,
4083 std::vector<RnnActivation> &activations,
4084 bool inputForget) {
4085
4086 #define LSTM_X_SLICE_RANGE(idx) \
4087 {idx + 0, 0, 0}, { idx + 1, batchSize, inputSize }
4088 #define LSTM_H_SLICE_RANGE(idx) \
4089 {idx + 0, 0, 0}, { idx + 1, batchSize, hiddenSize }
4090 #define LSTM_C_SLICE_RANGE(idx) \
4091 {idx + 0, 0, 0}, { idx + 1, batchSize, hiddenSize }
4092 #define LSTM_W_SLICE_RANGE(idx0, idx1) \
4093 {idx0, idx1 * hiddenSize, 0}, { idx0 + 1, (idx1 + 1) * hiddenSize, inputSize }
4094 #define LSTM_R_SLICE_RANGE(idx0, idx1) \
4095 {idx0, idx1 * hiddenSize, 0}, { \
4096 idx0 + 1, (idx1 + 1) * hiddenSize, hiddenSize \
4097 }
4098 #define LSTM_B_SLICE_RANGE(idx0, idx1) \
4099 {idx0, idx1 * hiddenSize}, { idx0 + 1, (idx1 + 1) * hiddenSize }
4100 #define LSTM_P_SLICE_RANGE(idx0, idx1) \
4101 {idx0, idx1 * hiddenSize}, { idx0 + 1, (idx1 + 1) * hiddenSize }
4102 #define LSTM_CREATE_FC(name, LHS, RHS, BIAS) \
4103 BIAS ? (Node *)createFullyConnected(name, LHS, RHS, BIAS) \
4104 : (Node *)createMatMul(name, LHS, RHS)
4105
4106 // Operator name.
4107 const std::string &opName = namePrefix.str();
4108
4109 // Get all size parameters.
4110 dim_t numDirections = (direction == RnnDirection::Bidirectional) ? 2 : 1;
4111 assert(X.dims().size() == 3 &&
4112 "ONNX LSTM input 'X' should have 3 dimensions!");
4113 dim_t seqLength = X.dims()[0];
4114 dim_t batchSize = X.dims()[1];
4115 dim_t inputSize = X.dims()[2];
4116
4117 // Validate W size.
4118 assert(W.dims().size() == 3 &&
4119 "ONNX LSTM input 'W' should have 3 dimensions!");
4120 assert(W.dims()[0] == numDirections && W.dims()[1] == 4 * hiddenSize &&
4121 W.dims()[2] == inputSize && "ONNX LSTM 'W' tensor size invalid!");
4122
4123 // Validate R size.
4124 assert(R.dims().size() == 3 &&
4125 "ONNX LSTM input 'R' should have 3 dimensions!");
4126 assert(R.dims()[0] == numDirections && R.dims()[1] == 4 * hiddenSize &&
4127 R.dims()[2] == hiddenSize && "ONNX LSTM 'R' tensor size invalid!");
4128
4129 // Validate B size.
4130 if (B.getNode()) {
4131 assert(B.dims().size() == 2 &&
4132 "ONNX LSTM input 'B' should have 2 dimensions!");
4133 assert(B.dims()[0] == numDirections && B.dims()[1] == 8 * hiddenSize &&
4134 "ONNX LSTM 'B' tensor size invalid!");
4135 }
4136
4137 // Validate initial_h size.
4138 assert(initial_h.getNode() &&
4139 "ONNX LSTM input 'initial_h' is mandatory. Null provided!");
4140 assert(initial_h.dims().size() == 3 &&
4141 "ONNX LSTM input 'initial_h' should have 2 dimensions!");
4142 assert(initial_h.dims()[0] == numDirections &&
4143 initial_h.dims()[1] == batchSize &&
4144 initial_h.dims()[2] == hiddenSize &&
4145 "ONNX LSTM 'initial_h' tensor size invalid!");
4146
4147 // Validate initial_c size.
4148 assert(initial_c.getNode() &&
4149 "ONNX LSTM input 'initial_c' is mandatory. Null provided!");
4150 assert(initial_c.dims().size() == 3 &&
4151 "ONNX LSTM input 'initial_c' should have 2 dimensions!");
4152 assert(initial_c.dims()[0] == numDirections &&
4153 initial_c.dims()[1] == batchSize &&
4154 initial_c.dims()[2] == hiddenSize &&
4155 "ONNX LSTM 'initial_c' tensor size invalid!");
4156
4157 // Validate P size.
4158 if (P.getNode()) {
4159 assert(P.dims().size() == 2 &&
4160 "ONNX LSTM input 'P' should have 2 dimensions!");
4161 assert(P.dims()[0] == numDirections && P.dims()[1] == 3 * hiddenSize &&
4162 "ONNX LSTM 'P' tensor size invalid!");
4163 }
4164
4165 // Validate number of activations.
4166 assert(activations.size() == numDirections * 3 &&
4167 "ONNX LSTM activations vector invalid!");
4168
4169 // Create X slices.
4170 std::vector<Node *> Xslices;
4171 for (dim_t t = 0; t < seqLength; t++) {
4172 auto XsliceName = opName + ".X" + std::to_string(t) + ".slice";
4173 Node *Xt = createSlice(XsliceName, X, LSTM_X_SLICE_RANGE(t));
4174 auto XreshapeName = opName + ".X" + std::to_string(t) + ".reshape";
4175 Xt = createReshape(XreshapeName, Xt, {batchSize, inputSize});
4176 Xslices.push_back(Xt);
4177 }
4178
4179 // Lambda to load forward/backward LSTM cell.
4180 auto loadLSTMCell = [&](bool forward, std::vector<NodeValue> &Yslices,
4181 NodeValue &Hslice, NodeValue &Cslice) {
4182 // Name prefix.
4183 std::string dirLabel = forward ? ".fw" : ".bw";
4184 std::string prefix = opName + ((numDirections > 1) ? dirLabel : "");
4185
4186 // Slice index used for creating weights slices.
4187 dim_t sliceIdx0 = 0;
4188 if (direction == RnnDirection::Bidirectional) {
4189 sliceIdx0 = forward ? 0 : 1;
4190 }
4191
4192 // Activations.
4193 size_t activationOffset = sliceIdx0 * 3;
4194 auto activationF = activations[activationOffset + 0];
4195 auto activationG = activations[activationOffset + 1];
4196 auto activationH = activations[activationOffset + 2];
4197
4198 // Create W slices (Required).
4199 NodeValue Wi =
4200 createSlice(prefix + ".Wi.", W, LSTM_W_SLICE_RANGE(sliceIdx0, 0));
4201 NodeValue Wo =
4202 createSlice(prefix + ".Wo.", W, LSTM_W_SLICE_RANGE(sliceIdx0, 1));
4203 NodeValue Wf =
4204 createSlice(prefix + ".Wf.", W, LSTM_W_SLICE_RANGE(sliceIdx0, 2));
4205 NodeValue Wc =
4206 createSlice(prefix + ".Wc.", W, LSTM_W_SLICE_RANGE(sliceIdx0, 3));
4207
4208 Wi = createReshape(prefix + ".Wi.reshape", Wi, {hiddenSize, inputSize});
4209 Wo = createReshape(prefix + ".Wo.reshape", Wo, {hiddenSize, inputSize});
4210 Wf = createReshape(prefix + ".Wf.reshape", Wf, {hiddenSize, inputSize});
4211 Wc = createReshape(prefix + ".Wc.reshape", Wc, {hiddenSize, inputSize});
4212
4213 Wi = createTranspose(prefix + ".Wi.transp", Wi, {1, 0});
4214 Wo = createTranspose(prefix + ".Wo.transp", Wo, {1, 0});
4215 Wf = createTranspose(prefix + ".Wf.transp", Wf, {1, 0});
4216 Wc = createTranspose(prefix + ".Wc.transp", Wc, {1, 0});
4217
4218 // Create R slices (Required).
4219 NodeValue Ri =
4220 createSlice(prefix + ".Ri.", R, LSTM_R_SLICE_RANGE(sliceIdx0, 0));
4221 NodeValue Ro =
4222 createSlice(prefix + ".Ro.", R, LSTM_R_SLICE_RANGE(sliceIdx0, 1));
4223 NodeValue Rf =
4224 createSlice(prefix + ".Rf.", R, LSTM_R_SLICE_RANGE(sliceIdx0, 2));
4225 NodeValue Rc =
4226 createSlice(prefix + ".Rc.", R, LSTM_R_SLICE_RANGE(sliceIdx0, 3));
4227
4228 Ri = createReshape(prefix + ".Ri.reshape", Ri, {hiddenSize, hiddenSize});
4229 Ro = createReshape(prefix + ".Ro.reshape", Ro, {hiddenSize, hiddenSize});
4230 Rf = createReshape(prefix + ".Rf.reshape", Rf, {hiddenSize, hiddenSize});
4231 Rc = createReshape(prefix + ".Rc.reshape", Rc, {hiddenSize, hiddenSize});
4232
4233 Ri = createTranspose(prefix + ".Ri.transp", Ri, {1, 0});
4234 Ro = createTranspose(prefix + ".Ro.transp", Ro, {1, 0});
4235 Rf = createTranspose(prefix + ".Rf.transp", Rf, {1, 0});
4236 Rc = createTranspose(prefix + ".Rc.transp", Rc, {1, 0});
4237
4238 // Create B slices (optional).
4239 NodeValue bWi = nullptr;
4240 NodeValue bWo = nullptr;
4241 NodeValue bWf = nullptr;
4242 NodeValue bWc = nullptr;
4243 NodeValue bRi = nullptr;
4244 NodeValue bRo = nullptr;
4245 NodeValue bRf = nullptr;
4246 NodeValue bRc = nullptr;
4247
4248 if (B) {
4249
4250 bWi = createSlice(prefix + ".bWi.", B, LSTM_B_SLICE_RANGE(sliceIdx0, 0));
4251 bWo = createSlice(prefix + ".bWo.", B, LSTM_B_SLICE_RANGE(sliceIdx0, 1));
4252 bWf = createSlice(prefix + ".bWf.", B, LSTM_B_SLICE_RANGE(sliceIdx0, 2));
4253 bWc = createSlice(prefix + ".bWc.", B, LSTM_B_SLICE_RANGE(sliceIdx0, 3));
4254 bRi = createSlice(prefix + ".bRi.", B, LSTM_B_SLICE_RANGE(sliceIdx0, 4));
4255 bRo = createSlice(prefix + ".bRo.", B, LSTM_B_SLICE_RANGE(sliceIdx0, 5));
4256 bRf = createSlice(prefix + ".bRf.", B, LSTM_B_SLICE_RANGE(sliceIdx0, 6));
4257 bRc = createSlice(prefix + ".bRc.", B, LSTM_B_SLICE_RANGE(sliceIdx0, 7));
4258
4259 bWi = createReshape(prefix + ".bWi.reshape", bWi, {hiddenSize});
4260 bWo = createReshape(prefix + ".bWo.reshape", bWo, {hiddenSize});
4261 bWf = createReshape(prefix + ".bWf.reshape", bWf, {hiddenSize});
4262 bWc = createReshape(prefix + ".bWc.reshape", bWc, {hiddenSize});
4263 bRi = createReshape(prefix + ".bRi.reshape", bRi, {hiddenSize});
4264 bRo = createReshape(prefix + ".bRo.reshape", bRo, {hiddenSize});
4265 bRf = createReshape(prefix + ".bRf.reshape", bRf, {hiddenSize});
4266 bRc = createReshape(prefix + ".bRc.reshape", bRc, {hiddenSize});
4267 }
4268
4269 // Create P slices (optional).
4270 NodeValue Pi = nullptr;
4271 NodeValue Po = nullptr;
4272 NodeValue Pf = nullptr;
4273
4274 if (P) {
4275
4276 Pi = createSlice(prefix + ".Pi.", P, LSTM_P_SLICE_RANGE(sliceIdx0, 0));
4277 Po = createSlice(prefix + ".Po.", P, LSTM_P_SLICE_RANGE(sliceIdx0, 1));
4278 Pf = createSlice(prefix + ".Pf.", P, LSTM_P_SLICE_RANGE(sliceIdx0, 2));
4279
4280 // Repeat P slices to match [batchSize, hiddenSize].
4281 Pi = createTile(prefix + ".Pi.repeat", Pi, batchSize, 0);
4282 Po = createTile(prefix + ".Po.repeat", Po, batchSize, 0);
4283 Pf = createTile(prefix + ".Pf.repeat", Pf, batchSize, 0);
4284 }
4285
4286 // Create H slice for this direction.
4287 Node *Hinit = createSlice(prefix + ".H.slice", initial_h,
4288 LSTM_H_SLICE_RANGE(sliceIdx0));
4289 Hinit =
4290 createReshape(prefix + ".H.reshape", Hinit, {batchSize, hiddenSize});
4291
4292 // Create C slice for this direction.
4293 Node *Cinit = createSlice(prefix + ".C.slice", initial_c,
4294 LSTM_C_SLICE_RANGE(sliceIdx0));
4295 Cinit =
4296 createReshape(prefix + ".C.reshape", Cinit, {batchSize, hiddenSize});
4297
4298 // Initialize.
4299 Node *Ht = Hinit;
4300 Node *Ct = Cinit;
4301
4302 // Unroll LSTM cell for all time steps.
4303 for (size_t t = 0; t < seqLength; t++) {
4304
4305 // Input for current time step.
4306 // For the reverse LSTM cell the inputs are provided in reverse order.
4307 Node *Xt = forward ? Xslices[t] : Xslices[seqLength - 1 - t];
4308
4309 // Forget gate: ft = f(Xt * Wf + bWf + Ht-1 * Rf + bRf + Pf . Ct-1).
4310 Node *ft = createAdd(prefix + ".F.add1",
4311 LSTM_CREATE_FC(prefix + ".F.fc1", Xt, Wf, bWf),
4312 LSTM_CREATE_FC(prefix + ".F.fc2", Ht, Rf, bRf));
4313 if (Pf) {
4314 ft = createAdd(prefix + ".F.add2", ft,
4315 createMul(prefix + ".F.mult", Pf, Ct));
4316 }
4317 ft = activationF(prefix + ".F.act", ft);
4318
4319 // Cell state candidate: ctild = g(Xt * Wc + bWc + Ht-1 * Rc + bRc).
4320 Node *ctild =
4321 createAdd(prefix + ".Ctild.add",
4322 LSTM_CREATE_FC(prefix + ".Ctild.fc1", Xt, Wc, bWc),
4323 LSTM_CREATE_FC(prefix + ".Ctild.fc2", Ht, Rc, bRc));
4324 ctild = activationG(prefix + ".Ctild.act", ctild);
4325
4326 // Input gate:
4327 // For inputForget == true:
4328 // it = 1 - ft.
4329 // For inputForget == false:
4330 // it = f(Xt * Wi + bWi + Ht-1 * Ri + bRi + Pi . Ct-1).
4331 Node *it;
4332 if (inputForget) {
4333 auto splatTy = ft->getNthResult(0).getType();
4334 it = createSub(prefix + ".I.sub",
4335 createSplat(prefix + ".I.splat", splatTy, 1.0), ft);
4336 } else {
4337 it = createAdd(prefix + ".I.add1",
4338 LSTM_CREATE_FC(prefix + ".I.fc1", Xt, Wi, bWi),
4339 LSTM_CREATE_FC(prefix + ".I.fc2", Ht, Ri, bRi));
4340 if (Pi) {
4341 it = createAdd(prefix + ".I.add2", it,
4342 createMul(prefix + ".I.mult", Pi, Ct));
4343 }
4344 it = activationF(prefix + ".I.act", it);
4345 }
4346
4347 // Cell state update: Ct = ft . Ct-1 + it . ctild.
4348 Ct = createAdd(prefix + ".C.add", createMul(prefix + ".C.mult1", ft, Ct),
4349 createMul(prefix + ".C.mult2", it, ctild));
4350
4351 // Output gate: ot = f(Xt * Wo + bWo + Ht-1 * Ro + bRo + Po . Ct).
4352 Node *ot = createAdd(prefix + ".O.add1",
4353 LSTM_CREATE_FC(prefix + ".O.fc1", Xt, Wo, bWo),
4354 LSTM_CREATE_FC(prefix + ".O.fc2", Ht, Ro, bRo));
4355 if (Po) {
4356 ot = createAdd(prefix + ".O.add2", ot,
4357 createMul(prefix + ".O.mult", Po, Ct));
4358 }
4359 ot = activationF(prefix + ".O.act", ot);
4360
4361 // Hidden state update: Ht = ot . h(Ct).
4362 Ht =
4363 createMul(prefix + ".H.mult", ot, activationH(prefix + ".H.act", Ct));
4364
4365 // Output.
4366 Yslices.push_back(Ht);
4367 }
4368
4369 // Updated state nodes.
4370 Hslice = Ht;
4371 Cslice = Ct;
4372 }; // End of local lambda "loadLSTMCell".
4373
4374 bool forwardEnabled = ((direction == RnnDirection::Forward) ||
4375 (direction == RnnDirection::Bidirectional));
4376 bool backwardEnabled = ((direction == RnnDirection::Reverse) ||
4377 (direction == RnnDirection::Bidirectional));
4378
4379 std::vector<NodeValue> YSlices;
4380 std::vector<NodeValue> Hslices;
4381 std::vector<NodeValue> Cslices;
4382
4383 // Load forward LSTM.
4384 std::vector<NodeValue> forwardYslices;
4385 if (forwardEnabled) {
4386 NodeValue forwardHslice;
4387 NodeValue forwardCslice;
4388 loadLSTMCell(/* forward */ true, forwardYslices, forwardHslice,
4389 forwardCslice);
4390 Hslices.push_back(forwardHslice);
4391 Cslices.push_back(forwardCslice);
4392 }
4393
4394 // Load backward LSTM.
4395 std::vector<NodeValue> backwardYslices;
4396 if (backwardEnabled) {
4397 NodeValue backwardHslice;
4398 NodeValue backwardCslice;
4399 loadLSTMCell(/* forward */ false, backwardYslices, backwardHslice,
4400 backwardCslice);
4401 Hslices.push_back(backwardHslice);
4402 Cslices.push_back(backwardCslice);
4403 }
4404
4405 // Gather Y slices.
4406 for (size_t t = 0; t < seqLength; t++) {
4407 if (forwardEnabled) {
4408 YSlices.push_back(forwardYslices[t]);
4409 }
4410 if (backwardEnabled) {
4411 YSlices.push_back(backwardYslices[seqLength - 1 - t]);
4412 }
4413 }
4414
4415 // Concatenate Y slices.
4416 // Y size is [seqLength, numDirections, batchSize, hiddenSize].
4417 Y = createReshape(opName + ".Y.reshape",
4418 createConcat(opName + ".Y.concat", YSlices, 0),
4419 {seqLength, numDirections, batchSize, hiddenSize});
4420
4421 // Concatenate Y_h slices.
4422 // Y_h size is [numDirections, batchSize, hiddenSize].
4423 Y_h = createReshape(opName + ".Y_h.reshape",
4424 createConcat(opName + ".Y_h.concat", Hslices, 0),
4425 {numDirections, batchSize, hiddenSize});
4426
4427 // Concatenate Y_c slices.
4428 // Y_c size is [numDirections, batchSize, hiddenSize].
4429 Y_c = createReshape(opName + ".Y_c.reshape",
4430 createConcat(opName + ".Y_c.concat", Cslices, 0),
4431 {numDirections, batchSize, hiddenSize});
4432
4433 #undef LSTM_X_SLICE_RANGE
4434 #undef LSTM_H_SLICE_RANGE
4435 #undef LSTM_C_SLICE_RANGE
4436 #undef LSTM_W_SLICE_RANGE
4437 #undef LSTM_R_SLICE_RANGE
4438 #undef LSTM_B_SLICE_RANGE
4439 #undef LSTM_P_SLICE_RANGE
4440 #undef LSTM_CREATE_FC
4441 }
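
// Note on the activations vector (typical usage; the concrete functions are
// the caller's choice): one {f, g, h} triple per direction, forward first,
// where f gates it/ft/ot, g produces the candidate ctild and h shapes the
// hidden output. The usual ONNX defaults are {Sigmoid, Tanh, Tanh}.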
4442
4443 TraceEventNode *Function::createTraceEvent(llvm::StringRef eventName,
4444 llvm::StringRef eventType,
4445 Node *data, unsigned index) {
4446 std::string name = (getName() + "_" + eventName + "_instrumentation").str();
4447 return addNode(new TraceEventNode(name, data, eventName, eventType, index));
4448 }
4449
4450 NonMaxSuppressionNode *Function::createNonMaxSuppressionV4(
4451 llvm::StringRef name, NodeValue boxes, NodeValue scores,
4452 int64_t centerPointBox, int64_t maxOutputBoxesPerClass, float iouThreshold,
4453 float scoreThreshold, ElemKind elTy) {
4454 // V4
4455 // Class/Score [BatchNum][BoxNum]
4456 // Boxes [BatchNum][BoxNum][4]
4457 // Result [BatchNum*MaxOutputPerBatch]
4458 // NumberOfIndicesDetected [BatchNum*MaxOutputPerBatch]
4459 auto scoresDim = scores.dims();
4460 int scoresBoxDim = scoresDim.size() - 1;
4461 if (maxOutputBoxesPerClass == 0) {
4462 maxOutputBoxesPerClass = scoresDim[scoresBoxDim];
4463 }
4464
4465 // Allocating maximum because we don't know how many boxes will actually be
4466 // detected.
4467 std::vector<dim_t> newDim = {static_cast<dim_t>(maxOutputBoxesPerClass)};
4468 auto indicesTy = getParent()->uniqueType(elTy, newDim);
4469 auto numberOfSelectedIndicesTy = getParent()->uniqueType(
4470 elTy, {static_cast<dim_t>(maxOutputBoxesPerClass)});
4471 return addNode(new NonMaxSuppressionNode(
4472 name, indicesTy, numberOfSelectedIndicesTy, boxes, scores, centerPointBox,
4473 maxOutputBoxesPerClass, iouThreshold, scoreThreshold, true));
4474 }
4475
4476 NonMaxSuppressionNode *
4477 Function::createNonMaxSuppressionV4(llvm::StringRef name, NodeValue boxes,
4478 NodeValue scores, int64_t centerPointBox,
4479 int64_t maxOutputBoxesPerClass,
4480 float iouThreshold, float scoreThreshold) {
4481 return createNonMaxSuppressionV4(name, boxes, scores, centerPointBox,
4482 maxOutputBoxesPerClass, iouThreshold,
4483 scoreThreshold, ElemKind::Int64ITy);
4484 }
4485
4486 NonMaxSuppressionNode *Function::createNonMaxSuppressionV4(
4487 llvm::StringRef name, NodeValue boxes, NodeValue scores,
4488 int64_t centerPointBox, int64_t maxOutputBoxesPerClass, float iouThreshold,
4489 float scoreThreshold, TypeRef indicesTy,
4490 TypeRef numberOfSelectedIndicesTy) {
4491 assert(maxOutputBoxesPerClass > 0 && "Invalid maxOutputBoxesPerClass.");
4492
4493 return addNode(new NonMaxSuppressionNode(
4494 name, indicesTy, numberOfSelectedIndicesTy, boxes, scores, centerPointBox,
4495 maxOutputBoxesPerClass, iouThreshold, scoreThreshold, true));
4496 }
4497
4498 NonMaxSuppressionNode *Function::createNonMaxSuppressionONNX(
4499 llvm::StringRef name, NodeValue boxes, NodeValue scores,
4500 int64_t centerPointBox, int64_t maxOutputBoxesPerClass, float iouThreshold,
4501 float scoreThreshold, ElemKind elTy) {
4502 // ONNX
4503 // Class/Score [BatchNum][ClassNum][BoxNum]
4504 // Box [BatchNum][BoxNum][4]
4505 // Result [BatchNum*MaxOutputPerBatch][3]
4506 auto boxesDim = boxes.dims();
4507 auto scoresDim = scores.dims();
4508 int scoresBoxDim = scoresDim.size() - 1;
4509 int scoresClassDim = scoresDim.size() - 2;
4510 int scoresBatchDim = scoresDim.size() - 3;
4511 int boxesBatchDim = boxesDim.size() - 3;
4512 if (maxOutputBoxesPerClass == 0) {
4513 maxOutputBoxesPerClass = scoresDim[scoresBoxDim];
4514 }
4515
4516 // Allocating maximum because we don't know how many boxes will actually be
4517 // detected.
4518 std::vector<dim_t> newDim = {scoresDim[scoresBatchDim] *
4519 scoresDim[scoresClassDim] *
4520 static_cast<dim_t>(maxOutputBoxesPerClass),
4521 3};
4522 auto indicesTy = getParent()->uniqueType(elTy, newDim);
4523 auto numberOfSelectedIndicesTy = getParent()->uniqueType(
4524 elTy,
4525 {boxesDim[boxesBatchDim] * static_cast<dim_t>(maxOutputBoxesPerClass)});
4526 return addNode(new NonMaxSuppressionNode(
4527 name, indicesTy, numberOfSelectedIndicesTy, boxes, scores, centerPointBox,
4528 maxOutputBoxesPerClass, iouThreshold, scoreThreshold, false));
4529 }
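
// A worked shape example for the ONNX layout above (numbers illustrative):
// for scores [1][2][100], boxes [1][100][4] and maxOutputBoxesPerClass = 10,
// the indices type is {1 * 2 * 10, 3} = {20, 3} and the number-of-selected
// type is {1 * 10} = {10}.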
4530
4531 NonMaxSuppressionNode *Function::createNonMaxSuppressionONNX(
4532 llvm::StringRef name, NodeValue boxes, NodeValue scores,
4533 int64_t centerPointBox, int64_t maxOutputBoxesPerClass, float iouThreshold,
4534 float scoreThreshold) {
4535 return createNonMaxSuppressionONNX(name, boxes, scores, centerPointBox,
4536 maxOutputBoxesPerClass, iouThreshold,
4537 scoreThreshold, ElemKind::Int64ITy);
4538 }
4539
4540 NonMaxSuppressionNode *Function::createNonMaxSuppressionONNX(
4541 llvm::StringRef name, NodeValue boxes, NodeValue scores,
4542 int64_t centerPointBox, int64_t maxOutputBoxesPerClass, float iouThreshold,
4543 float scoreThreshold, TypeRef indicesTy) {
4544 auto boxesDim = boxes.dims();
4545 assert(maxOutputBoxesPerClass > 0 && "Invalid maxOutputBoxesPerClass.");
4546
4547 // Allocating maximum because we don't know how many boxes will actually be
4548 // detected.
4549 auto numberOfSelectedIndicesTy = getParent()->uniqueType(
4550 ElemKind::Int32ITy, {1, 1, 1,
4551 boxesDim[boxesDim.size() - 2] *
4552 static_cast<dim_t>(maxOutputBoxesPerClass)});
4553 return addNode(new NonMaxSuppressionNode(
4554 name, indicesTy, numberOfSelectedIndicesTy, boxes, scores, centerPointBox,
4555 maxOutputBoxesPerClass, iouThreshold, scoreThreshold, false));
4556 }
4557
4558 Constant *Function::createCosineWindow(llvm::StringRef name, dim_t length) {
4559 auto window = getParent()->createConstant(ElemKind::FloatTy, {length}, name);
4560 auto windowH = window->getHandle<float>();
4561 for (dim_t n = 0; n < length; n++) {
4562 windowH.raw(n) =
4563 0.5 - 0.5 * cos(2.0 * M_PI * (double)(n) / (double)(length));
4564 }
4565 return window;
4566 }
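
// Hand-check: this is a periodic Hann window, so for length = 4 the values
// 0.5 - 0.5 * cos(2 * pi * n / 4) come out as {0.0, 0.5, 1.0, 0.5}.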
4567
4568 Constant *Function::createFFTTwiddleFactors(llvm::StringRef name,
4569 dim_t fftLength) {
4570 auto twiddleFactors =
4571 getParent()->createConstant(ElemKind::FloatTy, {2 * fftLength}, name);
4572 auto twiddleFactorsH = twiddleFactors->getHandle<float>();
4573 for (dim_t k = 0; k < fftLength; k++) {
4574 twiddleFactorsH.raw(2 * k + 0) =
4575 cos(2.0 * M_PI * (double)(k) / (double)(fftLength));
4576 twiddleFactorsH.raw(2 * k + 1) =
4577 -sin(2.0 * M_PI * (double)(k) / (double)(fftLength));
4578 }
4579 return twiddleFactors;
4580 }
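
// Hand-check: each pair stores (cos, -sin) of 2 * pi * k / fftLength, i.e.
// the complex twiddle factor e^(-j * 2 * pi * k / N). For fftLength = 4 the
// buffer is {1, 0, 0, -1, -1, 0, 0, 1}.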
4581
4582 Constant *Function::createFFTBitReverseIndices(llvm::StringRef name,
4583 dim_t fftLength) {
4584 assert(fftLength >= 1 && "FFT length must be at least 1!");
4585 // Local function to reverse the bits of a number.
4586 auto reverseBits = [](uint64_t bits, dim_t numBits) -> uint64_t {
4587 assert((numBits <= 64) &&
4588 "Maximum number of bits exceeded for 'reverseBits' function!");
4589 if (numBits == 0) {
4590 return 0;
4591 }
4592 uint64_t bitsRev = 0;
4593 uint64_t bitsMask = 1;
4594 uint64_t bitsRevMask = (uint64_t)1 << (numBits - 1);
4595 for (dim_t idx = 0; idx < numBits; idx++) {
4596 if (bits & bitsMask) {
4597 bitsRev |= bitsRevMask;
4598 }
4599 bitsMask <<= 1;
4600 bitsRevMask >>= 1;
4601 }
4602 return bitsRev;
4603 };
4604 auto bitReverseIndices =
4605 getParent()->createConstant(ElemKind::Int32ITy, {fftLength}, name);
4606 auto bitReverseIndicesH = bitReverseIndices->getHandle<int32_t>();
4607 dim_t numBits = std::log2((double)fftLength);
4608 for (dim_t idx = 0; idx < fftLength; idx++) {
4609 bitReverseIndicesH.raw(idx) =
4610 static_cast<int32_t>(reverseBits(idx, numBits));
4611 }
4612 return bitReverseIndices;
4613 }
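
// Hand-check: for fftLength = 8 (numBits = 3) the bit-reversed indices are
// {0, 4, 2, 6, 1, 5, 3, 7}.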
4614
4615 Constant *Function::createFFTComplexToRealWeights(llvm::StringRef name,
4616 dim_t fftLength,
4617 dim_t outLength) {
4618 auto complexToRealWeights =
4619 getParent()->createConstant(ElemKind::FloatTy, {2 * outLength}, name);
4620 auto complexToRealWeightsH = complexToRealWeights->getHandle<float>();
4621 for (dim_t k = 0; k < outLength; k++) {
4622 complexToRealWeightsH.raw(2 * k + 0) =
4623 0.5 * (1 - sin(2.0 * M_PI * (double)(k) / (double)(fftLength)));
4624 complexToRealWeightsH.raw(2 * k + 1) =
4625 -0.5 * cos(2.0 * M_PI * (double)(k) / (double)(fftLength));
4626 }
4627 return complexToRealWeights;
4628 }
4629
4630 AudioSpectrogramNode *Function::createAudioSpectrogram(llvm::StringRef name,
4631 NodeValue input,
4632 int64_t windowSize,
4633 int64_t windowStride,
4634 bool magnitudeSquared) {
4635 // Output shape will be windowCount x (fftLength / 2 + 1).
4636 dim_t inputLength = input.getType()->size();
4637 dim_t windowCount = std::floor((inputLength - windowSize) / windowStride) + 1;
4638 dim_t fftLength = 1 << (dim_t)std::ceil(std::log2((double)windowSize));
4639 auto spectrogramTy = getParent()->uniqueType(
4640 ElemKind::FloatTy, {windowCount, fftLength / 2 + 1});
4641
4642 // Create a cosine FFT windowing function.
4643 auto window = createCosineWindow(std::string(name) + ".Window", windowSize);
4644
4645 // Create the FFT weights for a fftLength/2 complex FFT.
4646 auto twiddleFactors = createFFTTwiddleFactors(
4647 std::string(name) + ".TwiddleFactors", fftLength / 2);
4648 auto bitReverseIndices = createFFTBitReverseIndices(
4649 std::string(name) + ".BitReverseIndices", fftLength / 2);
4650
4651 // Create the complex to real FFT mapping coefficients.
4652 // For small FFT lengths make sure to generate at least 1 coefficient.
4653 auto complexToRealWeights = createFFTComplexToRealWeights(
4654 std::string(name) + ".ComplexToRealWeights", fftLength,
4655 (fftLength / 4) >= 1 ? (fftLength / 4) : 1);
4656
4657 // Create AudioSpectrogram node.
4658 return addNode(new AudioSpectrogramNode(
4659 name, spectrogramTy, input, window, twiddleFactors, bitReverseIndices,
4660 complexToRealWeights, windowSize, windowStride, magnitudeSquared));
4661 }
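
// A worked sizing example (numbers illustrative): for a 16000-sample input
// with windowSize = 640 and windowStride = 320, windowCount =
// (16000 - 640) / 320 + 1 = 49 and fftLength = 2^ceil(log2(640)) = 1024, so
// the spectrogram type is {49, 513}.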
4662
4663 void Function::createMelWeights(llvm::StringRef prefix, dim_t spectrogramLength,
4664 float sampleRate, float lowerFrequency,
4665 float upperFrequency, dim_t filterBankCount,
4666 Constant *&melWeights, Constant *&melRanges) {
4667 auto fftLength = 2 * (spectrogramLength - 1);
4668 dim_t numFreqBins = fftLength / 2;
4669 dim_t numMelBins = filterBankCount;
4670
4671 // Mel frequency scale local lambda function.
4672 auto melFreqScale = [](float freq) -> float {
4673 return 1127.0f * logf(1.0f + freq / 700.0f);
4674 };
4675
4676 // Always exclude DC (TensorFlow implementation choice from HTK).
4677 float freqDelta = sampleRate / (float)(fftLength);
4678 dim_t freqIdxMin = (dim_t)(1.5 + (lowerFrequency / freqDelta));
4679 dim_t freqIdxMax = (dim_t)(upperFrequency / freqDelta);
4680 freqIdxMax = (freqIdxMax >= numFreqBins) ? numFreqBins : freqIdxMax;
4681
4682 // Create Mel ranges constant.
4683 melRanges = getParent()->createConstant(ElemKind::Int32ITy, {2 * numMelBins},
4684 std::string(prefix) + ".MelRanges");
4685 auto melRangesH = melRanges->getHandle<int32_t>();
4686
4687 // Mel weights and frequency start/stop (inclusive) buffers.
4688 auto melBinFreqWeights = std::make_unique<float[]>(numMelBins * numFreqBins);
4689 dim_t melBinFreqWeightsNum = 0;
4690
4691 // Mel frequency limits.
4692 float melFreqLower = melFreqScale(lowerFrequency);
4693 float melFreqUpper = melFreqScale(upperFrequency);
4694 float melFreqDelta = (melFreqUpper - melFreqLower) / (numMelBins + 1);
4695 for (dim_t melIdx = 0; melIdx < numMelBins; melIdx++) {
4696
4697 float melFreqLeft = melFreqLower + (melIdx + 0) * melFreqDelta;
4698 float melFreqCenter = melFreqLower + (melIdx + 1) * melFreqDelta;
4699 float melFreqRight = melFreqLower + (melIdx + 2) * melFreqDelta;
4700
4701 int32_t freqIdxStart = -1;
4702 int32_t freqIdxStop = -2;
4703
4704 for (dim_t freqIdx = freqIdxMin; freqIdx <= freqIdxMax; freqIdx++) {
4705 float melFreq = melFreqScale(freqIdx * freqDelta);
4706 if ((melFreqLeft < melFreq) && (melFreq < melFreqRight)) {
4707
4708 // Compute frequency bin weight for this Mel bin.
4709 float weight = 1.0f - std::abs(melFreq - melFreqCenter) / melFreqDelta;
4710
4711 // Store the frequency bin weight.
4712 melBinFreqWeights[melBinFreqWeightsNum++] = weight;
4713
4714 // Update frequency bin start/stop index.
4715 if (freqIdxStart == -1) {
4716 freqIdxStart = freqIdx;
4717 }
4718 freqIdxStop = freqIdx;
4719 }
4720 }
4721
4722 // Store the frequency bin start/stop index.
4723 melRangesH.raw(2 * melIdx + 0) = freqIdxStart;
4724 melRangesH.raw(2 * melIdx + 1) = freqIdxStop;
4725 }
4726
4727 // Validate Mel ranges.
4728 dim_t melBinFreqWeightsNumValidate = 0;
4729 for (dim_t melIdx = 0; melIdx < numMelBins; melIdx++) {
4730 int32_t freqIdxRange =
4731 melRangesH.raw(2 * melIdx + 1) - melRangesH.raw(2 * melIdx + 0) + 1;
4732 melBinFreqWeightsNumValidate += freqIdxRange;
4733 }
4734 assert(melBinFreqWeightsNum == melBinFreqWeightsNumValidate &&
4735 "Invalid Mel ranges");
4736
4737 // Create Mel weights constant.
4738 melWeights =
4739 getParent()->createConstant(ElemKind::FloatTy, {melBinFreqWeightsNum},
4740 std::string(prefix) + ".MelWeights");
4741 auto melWeightsH = melWeights->getHandle<float>();
4742 for (dim_t idx = 0; idx < melBinFreqWeightsNum; idx++) {
4743 melWeightsH.raw(idx) = melBinFreqWeights[idx];
4744 }
4745 }
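
// Hand-check of the HTK Mel scale above: melFreqScale(1000) =
// 1127 * ln(1 + 1000 / 700) is approximately 1000 mel; the 1127 constant is
// chosen precisely so that 1000 Hz maps to 1000 mel.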
4746
4747 Constant *Function::createDCTMat(llvm::StringRef name, dim_t N, dim_t K) {
4748 Constant *dctMat =
4749 getParent()->createConstant(ElemKind::FloatTy, {K, N}, name);
4750 auto dctMatH = dctMat->getHandle<float>();
4751 float dctFact = (float)sqrt(2.0 / (double)(N));
4752 for (dim_t k = 0; k < K; k++) {
4753 for (dim_t n = 0; n < N; n++) {
4754 dctMatH.at({k, n}) =
4755 dctFact * cos(M_PI / (double)(N) * ((double)(n) + 0.5) * (double)(k));
4756 }
4757 }
4758 return dctMat;
4759 }
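
// Hand-check of the DCT-II matrix: for N = K = 2, dctFact = 1 and the matrix
// is {{1, 1}, {cos(pi/4), cos(3*pi/4)}} = {{1, 1}, {0.7071, -0.7071}}.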
4760
4761 MFCCNode *Function::createMFCC(llvm::StringRef name, NodeValue spectrogram,
4762 float sampleRate, float lowerFrequency,
4763 float upperFrequency, int64_t filterBankCount,
4764 int64_t numCoefficients) {
4765 // Create the Mel weights.
4766 dim_t spectrogramLength = spectrogram.dims()[1];
4767 Constant *melWeights;
4768 Constant *melRanges;
4769 createMelWeights(name, spectrogramLength, sampleRate, lowerFrequency,
4770 upperFrequency, filterBankCount, melWeights, melRanges);
4771
4772 // Create the DCT transform matrix.
4773 Constant *dctMat = createDCTMat(std::string(name) + ".DCTMat",
4774 filterBankCount, numCoefficients);
4775
4776 // Output shape will be windowCount x numCoefficients.
4777 dim_t windowCount = spectrogram.dims()[0];
4778 auto coefficientsTy = getParent()->uniqueType(
4779 ElemKind::FloatTy, {windowCount, static_cast<dim_t>(numCoefficients)});
4780
4781 // Create MFCC node.
4782 return addNode(new MFCCNode(name, coefficientsTy, spectrogram, melWeights,
4783 melRanges, dctMat, sampleRate, lowerFrequency,
4784 upperFrequency, filterBankCount,
4785 numCoefficients));
4786 }
4787
4788 //===----------------------------------------------------------------------===//
4789 // Graph dumping and printing
4790 //===----------------------------------------------------------------------===//
4791
4792 void Function::dump() const {
4793 llvm::outs() << "Graph structure " << getName() << ":\n";
4794 for (auto &n : nodes_) {
4795 llvm::outs() << n.getDebugDesc();
4796 }
4797 }
4798
4799 std::string Function::toString(bool skipUsersForStorage, bool skipName) const {
4800 std::string storage;
4801 llvm::raw_string_ostream os(storage);
4802 dump(os, skipUsersForStorage, skipName);
4803 return os.str();
4804 }
4805
4806 llvm::hash_code Function::getHash() const {
4807 // Omit function name when generating the hash.
4808 return llvm::hash_value(toString(/* skipUsersForStorage */ false,
4809 /* skipName */ true));
4810 }
4811
4812 void Function::dump(llvm::raw_ostream &os, bool skipUsersForStorage,
4813 bool skipName) const {
4814 os << "Graph structure";
4815 if (!skipName) {
4816 os << " " << getName();
4817 }
4818 os << ":\n";
4819 std::set<const Node *, SortNamed> sorted;
4820 for (const Node &n : nodes_) {
4821 sorted.insert(&n);
4822 }
4823 for (auto *n : sorted) {
4824 os << n->getDebugDesc();
4825 }
4826 for (auto *C : getNamedSorted(findConstants())) {
4827 os << C->getDebugDesc(skipUsersForStorage);
4828 }
4829 for (auto *P : getNamedSorted(findPlaceholders())) {
4830 os << P->getDebugDesc(skipUsersForStorage);
4831 }
4832 }
4833
4834 /// We can't use NodeWalker here, because it ignores result indices, which
4835 /// are critical in generating detailed debug output.
4836 class FunctionDottyPrinter : public AbstractDottyPrinter {
4837 // A set of already visited (during graph walk) nodes.
4838 std::unordered_set<Node *> visitedNodes_{};
4839
4840 /// Recursively traverses inputs of node \p N using depth-first search.
4841 /// Each node will be visited no more than once. The method also dumps
4842 /// edges with their port identifiers in dotty format.
4843 void visitNode(Node *N) {
4844 if (visitedNodes_.find(N) != visitedNodes_.end())
4845 return;
4846 visitedNodes_.insert(N);
4847
4848 dumpNode(N, false);
4849
4850 // Print edges for the predicate field, if it's used.
4851 if (N->hasPredicate()) {
4852 auto pred = N->getPredicate();
4853 size_t resNo = pred.getResNo();
4854 std::ostringstream edge;
4855 edge << pred.getNode()->getName().str() << ":"
4856 << pred.getNode()->getOutputName(resNo).str() << " -> "
4857 << N->getName().str() << ":w";
4858 dumpEdgeStyle(N, 0, pred, edge);
4859 edges_.insert(edge.str());
4860 visitNode(pred);
4861 }
4862
4863 for (size_t i = 0; i < N->getNumInputs(); i++) {
4864 Node *to = N->getNthInput(i).getNode();
4865 size_t resNo = N->getNthInput(i).getResNo();
4866
4867 std::ostringstream edge;
4868 edge << to->getName().str() << ":" << to->getOutputName(resNo).str()
4869 << " -> " << N->getName().str() << ":" << N->getInputName(i);
4870 dumpEdgeStyle(N, i, to, edge);
4871 edges_.insert(edge.str());
4872
4873 visitNode(to);
4874 }
4875 }
4876
4877 public:
4878 void visitGraph(Function *F) {
4879 for (auto &N : F->getNodes()) {
4880 visitNode(&N);
4881 }
4882 }
4883 };
4884
4885 std::string Function::dumpDAG() {
4886 llvm::SmallString<64> dotPath;
4887 llvm::sys::fs::createTemporaryFile("dotty_graph_dump", "dot", dotPath);
4888 dumpDAG(dotPath);
4889
4890 return std::string(dotPath.begin(), dotPath.end());
4891 }
4892
4893 void Function::dumpDAG(llvm::StringRef dotFilename) {
4894 llvm::outs() << "Writing dotty graph for Function to: " << dotFilename
4895 << '\n';
4896
4897 FunctionDottyPrinter DP;
4898
4899 DP.visitGraph(this);
4900
4901 std::ofstream myfile;
4902 myfile.open(dotFilename);
4903 DP.dumpAll(myfile);
4904 myfile.close();
4905 }
4906
4907 void Function::dumpDAG(const char *dotFilename) {
4908 dumpDAG(llvm::StringRef(dotFilename));
4909 }
4910
4911 Node *Function::getNodeByName(llvm::StringRef name) {
4912 for (auto &N : getNodes()) {
4913 if (N.getName().equals(name)) {
4914 return &N;
4915 }
4916 }
4917 return nullptr;
4918 }
4919
4920 NodeValue Function::getNodeValueByName(llvm::StringRef name) {
4921 auto strPair = name.split(':');
4922 // Search node, constant or placeholder.
4923 auto nodeName = strPair.first;
4924 Node *node = getNodeByName(nodeName);
4925 node = node ? node : getParent()->getConstantByName(nodeName);
4926 node = node ? node : getParent()->getPlaceholderByNameSlow(nodeName);
4927 if (!node || (node->getNumResults() == 0)) {
4928 return NodeValue();
4929 }
4930 // Get result number.
4931 if (node->getNumResults() == 1) {
4932 return NodeValue(node);
4933 } else {
4934 unsigned resNo = 0;
4935 CHECK(!strPair.second.getAsInteger(0, resNo)) << "Invalid node value name!";
4936 return NodeValue(node, resNo);
4937 }
4938 }
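
// Usage sketch (names illustrative): getNodeValueByName("relu1") returns the
// single result of node "relu1", while getNodeValueByName("topk:1") returns
// result 1 of multi-output node "topk". An empty NodeValue means no node,
// constant or placeholder matched.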
4939
4940 void Module::eraseConstant(ConstList::iterator I) {
4941 if (I == constants_.end())
4942 return;
4943 logStorageDeletion(functions_, *I);
4944 delete *I;
4945 constants_.erase(I);
4946 }
4947
4948 void Module::erasePlaceholder(PlaceholderList::iterator I) {
4949 if (I == placeholders_.end()) {
4950 return;
4951 }
4952
4953 logStorageDeletion(functions_, *I);
4954 delete *I;
4955 placeholders_.erase(I);
4956 }
4957
4958 void Function::eraseNode(NodesList::iterator I) {
4959 // Log node deletion.
4960 logCtx_->logNodeDeletion(*I);
4961
4962 nodes_.erase(I);
4963 }
4964
4965 Constant *Module::getConstantByName(llvm::StringRef name) const {
4966 for (auto *V : getConstants()) {
4967 if (V->getName() == name)
4968 return V;
4969 }
4970 return nullptr;
4971 }
4972
4973 void Function::randomizeConstants(
4974 const std::map<Kinded::Kind, std::set<unsigned>> &ignoredConstants) {
4975 for (Constant *c : getParent()->getConstants()) {
4976 bool usedHere = false;
4977 bool usedElsewhere = false;
4978 bool ignored = false;
4979
4980 for (auto &user : c->getUsers()) {
4981 auto *nodeUser = user.getUser();
4982 if (nodeUser->getParent() == this) {
4983 usedHere = true;
4984 } else {
4985 usedElsewhere = true;
4986 }
4987
4988 auto kind = nodeUser->getKind();
4989 if (ignoredConstants.count(kind)) {
4990 for (auto idx : ignoredConstants.at(kind)) {
4991 if (nodeUser->getNthInput(idx).getNode() == c) {
4992 ignored = true;
4993 break;
4994 }
4995 }
4996 }
4997 }
4998
4999 if (!usedHere) {
5000 continue;
5001 }
5002
5003 if (usedElsewhere) {
5004 LOG(FATAL) << "Can't randomize Constant \"" << c->getName().str()
5005 << "\" because it is used by another function";
5006 }
5007
5008 if (ignored) {
5009 continue;
5010 }
5011
5012 auto &payload = c->getPayloadMutable();
5013
5014 switch (c->getElementType()) {
5015 case ElemKind::FloatTy: {
5016 auto H = payload.getHandle<float>();
5017 auto minMaxArg = H.minMaxArg();
5018 H.randomize(H.raw(minMaxArg.first), H.raw(minMaxArg.second), getPRNG());
5019 break;
5020 }
5021 case ElemKind::Float16Ty: {
5022 auto H = payload.getHandle<float16_t>();
5023 auto minMaxArg = H.minMaxArg();
5024 H.randomize(H.raw(minMaxArg.first), H.raw(minMaxArg.second), getPRNG());
5025 break;
5026 }
5027 case ElemKind::BFloat16Ty: {
5028 auto H = payload.getHandle<bfloat16_t>();
5029 auto minMaxArg = H.minMaxArg();
5030 H.randomize(H.raw(minMaxArg.first), H.raw(minMaxArg.second), getPRNG());
5031 break;
5032 }
5033 case ElemKind::Int8QTy: {
5034 auto H = payload.getHandle<int8_t>();
5035 auto minMaxArg = H.minMaxArg();
5036 H.randomize(H.raw(minMaxArg.first), H.raw(minMaxArg.second), getPRNG());
5037 break;
5038 }
5039 case ElemKind::UInt8QTy: {
5040 auto H = payload.getHandle<uint8_t>();
5041 auto minMaxArg = H.minMaxArg();
5042 H.randomize(H.raw(minMaxArg.first), H.raw(minMaxArg.second), getPRNG());
5043 break;
5044 }
5045 case ElemKind::Int16QTy: {
5046 auto H = payload.getHandle<int16_t>();
5047 auto minMaxArg = H.minMaxArg();
5048 H.randomize(H.raw(minMaxArg.first), H.raw(minMaxArg.second), getPRNG());
5049 break;
5050 }
5051 case ElemKind::Int32QTy: {
5052 auto H = payload.getHandle<int32_t>();
5053 auto minMaxArg = H.minMaxArg();
5054 H.randomize(H.raw(minMaxArg.first), H.raw(minMaxArg.second), getPRNG());
5055 break;
5056 }
5057 case ElemKind::Int32ITy: {
5058 auto H = payload.getHandle<int32_t>();
5059 auto minMaxArg = H.minMaxArg();
5060 H.randomize(H.raw(minMaxArg.first), H.raw(minMaxArg.second), getPRNG());
5061 break;
5062 }
5063 case ElemKind::Int64ITy: {
5064 auto H = payload.getHandle<int64_t>();
5065 auto minMaxArg = H.minMaxArg();
5066 H.randomize(H.raw(minMaxArg.first), H.raw(minMaxArg.second), getPRNG());
5067 break;
5068 }
5069 case ElemKind::UInt8FusedQTy:
5070 payload.getHandle<uint8_t>().randomize(
5071 std::numeric_limits<uint8_t>::lowest(),
5072 std::numeric_limits<uint8_t>::max(), getPRNG());
5073 break;
5074 case ElemKind::UInt8FusedFP16QTy:
5075 payload.getHandle<uint8_t>().randomize(
5076 std::numeric_limits<uint8_t>::lowest(),
5077 std::numeric_limits<uint8_t>::max(), getPRNG());
5078 break;
5079 case ElemKind::UInt4FusedFP16QTy:
5080 payload.getHandle<uint8_t>().randomize(
5081 std::numeric_limits<uint8_t>::lowest(),
5082 std::numeric_limits<uint8_t>::max(), getPRNG());
5083 break;
5084 case ElemKind::BoolTy:
5085 payload.getHandle<bool>().randomize(false, true, getPRNG());
5086 break;
5087 default:
5088 LOG(FATAL) << "Unsupported ElemKind";
5089 }
5090 }
5091 }
5092
5093 Placeholder *Module::getPlaceholderByNameSlow(llvm::StringRef name) const {
5094 for (auto *P : getPlaceholders()) {
5095 if (P->getName() == name) {
5096 return P;
5097 }
5098 }
5099
5100 return nullptr;
5101 }
5102
5103 void Module::eraseConstant(Constant *N) {
5104 auto &vars = getConstants();
5105 auto I = std::find(vars.begin(), vars.end(), N);
5106 eraseConstant(I);
5107 }
5108
5109 void Function::eraseNode(Node *N) {
5110 if (Constant *V = dyn_cast<Constant>(N)) {
5111 return getParent()->eraseConstant(V);
5112 }
5113 assert(std::find_if(nodes_.begin(), nodes_.end(),
5114 [N](const Node &node) -> bool { return &node == N; }) !=
5115 nodes_.end() &&
5116 "Could not find node to delete!");
5117 eraseNode(N->getIterator());
5118 }

PlaceholderList Function::findPlaceholders() {
  PlaceholderList list;
  for (auto &PH : parent_->getPlaceholders()) {
    for (auto &user : PH->getUsers()) {
      if (user.getUser()->getParent() == this) {
        list.push_back(PH);
        break;
      }
    }
  }
  return list;
}

PlaceholderList Function::findPlaceholders() const {
  PlaceholderList list;
  for (auto &PH : parent_->getPlaceholders()) {
    for (auto &user : PH->getUsers()) {
      if (user.getUser()->getParent() == this) {
        list.push_back(PH);
        break;
      }
    }
  }
  return list;
}
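
// Both overloads above (and the findConstants() pair below) filter the
// module-level storage lists down to the entries actually used by this
// function. A usage sketch (the function pointer `F` is hypothetical):
//   for (Placeholder *PH : F->findPlaceholders()) {
//     llvm::outs() << PH->getName() << "\n";
//   }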

ConstList Function::findConstants() {
  ConstList list;
  for (auto &constant : parent_->getConstants()) {
    for (auto &user : constant->getUsers()) {
      if (user.getUser()->getParent() == this) {
        list.push_back(constant);
        break;
      }
    }
  }
  return list;
}

ConstList Function::findConstants() const {
  ConstList list;
  for (auto &constant : parent_->getConstants()) {
    for (auto &user : constant->getUsers()) {
      if (user.getUser()->getParent() == this) {
        list.push_back(constant);
        break;
      }
    }
  }
  return list;
}

Function *Function::clone(llvm::StringRef newName,
                          llvm::DenseMap<const Node *, Node *> *map,
                          llvm::DenseMap<const Node *, Node *> *currToNewMap) {
  Module *M = getParent();
  auto *newF = M->createFunction(newName);
  return clone(newF, map, currToNewMap);
}

Function *
Function::clone(Function *newF, llvm::DenseMap<const Node *, Node *> *map,
                llvm::DenseMap<const Node *, Node *> *currToNewMap) const {
  // Maps current nodes to new nodes.
  llvm::DenseMap<const Node *, Node *> currToNew;

  // Initialize the map from a user-provided map.
  if (currToNewMap) {
    currToNew.insert(currToNewMap->begin(), currToNewMap->end());
  }

  // Clone all of the nodes in the function.
  for (auto &N : getNodes()) {
    Node *copy = N.clone();
    // Record the copy relationship between the graphs.
    currToNew[&N] = copy;
    newF->addNode(copy);
    if (N.hasPredicate()) {
      copy->setPredicate(N.getPredicate());
    }
  }

  // At this point we have a new, not-yet-valid function whose nodes still
  // point into the original function. Update the links between the nodes in
  // the new function.
  for (auto &N : newF->getNodes()) {
    // Fix each one of the inputs of this node.
    for (unsigned inp = 0, e = N.getNumInputs(); inp < e; inp++) {
      auto input = N.getNthInput(inp);

      auto it = currToNew.find(input.getNode());
      if (it == currToNew.end()) {
        assert(isa<Storage>(input.getNode()) &&
               "Could not find a mapping for some node!");
        continue;
      }

      // Update the input to point at the corresponding node in the new
      // function.
      N.setNthInput(inp, NodeValue(it->second, input.getResNo()));
    }

    if (N.hasPredicate()) {
      auto it = currToNew.find(N.getPredicate().getNode());
      if (it != currToNew.end()) {
        N.setPredicate(NodeValue(it->second, N.getPredicate().getResNo()));
      }
    }
  }

  // Record the node mapping into the external map.
  if (map) {
    assert(map->empty() && "The external map must be empty");
    for (auto it : currToNew) {
      map->insert(it);
    }
  }

  assert(newF->getNodes().size() == getNodes().size() && "Invalid func size");
  return newF;
}
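
// A usage sketch for the cloning APIs above (names are hypothetical). The
// optional external map lets callers relate original nodes to their clones,
// e.g. to transfer per-node metadata after the copy:
//   llvm::DenseMap<const Node *, Node *> cloneMap;
//   Function *copyF = F->clone("F_copy", &cloneMap);
//   // cloneMap[&N] is now the counterpart of N inside copyF.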

/// Verify the input \p idx of a node \p N. Check that the node \p N is in the
/// use-list of the corresponding input node.
static bool verifyNodeInput(const Node &N, size_t idx) {
  auto input = N.getNthInput(idx);
  auto *refN = input.getNode();
  // Check that N is in the use-list of the input node and there is a proper
  // entry for it.
  for (auto &U : refN->getUsers()) {
    if (U.getUser() == &N && *U.get() == input) {
      return true;
    }
  }

  report("Any node referencing another node N must be in the use-list of the "
         "node N");
  return false;
}
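
// The invariant checked above is the def-use symmetry of the graph: if a node
// A consumes NodeValue (B, resNo), then B's use-list must contain an entry
// whose user is A and whose value is exactly (B, resNo). Mutators such as
// setNthInput() are expected to keep both sides in sync; the verifier only
// detects when they have drifted apart.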

Module *Module::clone() const {
  auto *M = new Module;
  return clone(M);
}

Module *Module::clone(Module *M) const {
  // Maps current nodes to new nodes.
  llvm::DenseMap<const Node *, Node *> currToNew;
  // Clone placeholders.
  for (auto &PH : getPlaceholders()) {
    auto *copyPH = M->createPlaceholder(PH->getType(), PH->getName(),
                                        PH->isTraining(), PH->getLayout());
    currToNew[PH] = copyPH;
  }
  // Clone constants.
  for (auto &C : getConstants()) {
    // The cloner cannot decide on its own what to do with constants that have
    // unowned payloads. Some kind of policy/hook may be needed in the future
    // for deciding what to do in such cases.
    DCHECK(!C->getPayload().isUnowned())
        << "Cannot copy constant " << C->getName().str()
        << ": Unowned payloads are not supported";
    auto *copyC = M->createConstant(C->getType(), C->getName(), C->getLayout());
    copyC->assign(&C->getPayload());
    currToNew[C] = copyC;
  }
  // Clone all functions.
  for (auto *F : getFunctions()) {
    // Create an empty clone function in the new module.
    auto *copyF = M->createFunction(F->getName());
    // Clone the function's body into the newly created function. Use
    // currToNew to properly map constants and placeholders.
    F->clone(copyF, nullptr, &currToNew);
    // Update all node types to the types interned in the new module.
    for (auto &N : copyF->getNodes()) {
      for (unsigned idx = 0, e = N.getNumResults(); idx < e; ++idx) {
        N.setType(idx, M->uniqueType(*N.getType(idx)));
      }
    }
  }
  return M;
}
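
// A minimal usage sketch for the module-cloning overloads above (variable
// names hypothetical):
//   Module M;
//   ... // build functions, constants, and placeholders in M
//   Module *copy = M.clone();
// Note the final setType() pass: node types must be re-interned via
// uniqueType() so the clone references type objects owned by the new module
// rather than by the original one.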

/// \returns True if \p n is a storage node (constant or placeholder) of the
/// function \p F.
static bool isGraphStorageNode(Node *n, const Function *F) {
  auto &vars = F->getParent()->getConstants();
  auto &placeholders = F->getParent()->getPlaceholders();

  if (Constant *V = dyn_cast<Constant>(n)) {
    return std::find(vars.begin(), vars.end(), V) != vars.end();
  }

  if (Placeholder *P = dyn_cast<Placeholder>(n)) {
    return std::find(placeholders.begin(), placeholders.end(), P) !=
           placeholders.end();
  }

  return false;
}

/// Insert \p node in \p nameToNode and report an error if the insertion fails.
/// \returns True if \p node was inserted into \p nameToNode. False otherwise.
/// When true is returned that means that \p nameToNode had no other nodes
/// registered under \p node.getName().
static bool
insertAndReport(std::unordered_map<std::string, const Node *> &nameToNode,
                const Node &node, const Function &function) {
  bool inserted = expectCompareTrue(
      "Node is not unique", nameToNode.insert({node.getName(), &node}).second,
      true, &function);
  if (!inserted) {
    std::string storage;
    llvm::raw_string_ostream msg(storage);
    // Output extra information helping to find the error.
    msg << "The node with name '" << node.getName()
        << "' conflicts with a previous definition:\n";
    msg << "Current definition: " << node.getDebugDesc() << "\n";
    msg << "Previous definition: "
        << nameToNode[node.getName()]->getDebugDesc();
    report(msg.str().c_str());
    return false;
  }
  return true;
}

bool Function::verify(const Backend *backend) const {
  bool isValid = true;
  if (backend) {
    if (backend->getTensorLayoutRequirements().isEnabled()) {
      isValid &= expectCompareTrue(
          "Expected correct backend-specific layouts for the graph",
          verifyLayouts(*this, backend->getTensorLayoutRequirements()), true,
          this);
    }
  } else {
    // Always run verification pre-lowering / when we don't have a backend:
    isValid &= expectCompareTrue(
        "Expected correct Glow canonical layouts for the graph",
        verifyLayouts(*this, CanonicalTensorLayout::getInstance()), true, this);
  }
  std::unordered_map<std::string, const Node *> nameToNode;

  for (auto *V : findConstants()) {
    isValid &= insertAndReport(nameToNode, *V, *this);
    isValid &=
        expectCompareTrue("Constant and its payload must have the same type",
                          *V->getType(), V->getPayload().getType(), V);
  }

  nameToNode.clear();
  for (const auto &N : nodes_) {
    isValid &= insertAndReport(nameToNode, N, *this);
  }

  // Any node referenced by one of the graph nodes should be part of the
  // graph.
  for (const auto &N : nodes_) {
    for (size_t idx = 0, e = N.getNumInputs(); idx < e; ++idx) {
      auto &input = N.getNthInput(idx);
      // Verify each input of N.
      isValid &= verifyNodeInput(N, idx);
      bool foundNode =
          std::find(nodes_.begin(), nodes_.end(), *input) != nodes_.end();
      isValid &= expectCompareTrue(
          "Every node referenced by one of the graph nodes should be part of "
          "the graph",
          foundNode || isGraphStorageNode(input, this), true, &N);
    }
  }

  // Check that all uses of each node refer to this node.
  for (const auto &N : nodes_) {
    for (const auto &U : N.getUsers()) {
      isValid &= expectCompareTrue<const Node *>(
          "All uses of a node should refer to this node", U.get()->getNode(),
          &N, &N);
    }
  }

  // Check that all types used by nodes belong to the parent module.
  auto &types = getParent()->getTypes();
  for (const auto &N : nodes_) {
    for (size_t idx = 0, e = N.getNumResults(); idx < e; ++idx) {
      auto ty = N.getType(idx);
      bool foundType =
          std::find(types.begin(), types.end(), *ty) != types.end();
      isValid &= expectCompareTrue(
          "Every type used by one of the graph nodes should be part of "
          "the module",
          foundType, true, &N);
    }
  }

  std::unordered_map<const Placeholder *, const Node *> placeholderWrittenTo;
  for (const auto &N : nodes_) {
    isValid &=
        expectCompareTrue("Node is not linked to the function it belongs to",
                          N.getParent(), this, &N);
    isValid &= N.verify();
    // Make sure all the placeholders are written to at most once, and that
    // constants are never written to.
    for (size_t idx = 0, e = N.getNumInputs(); idx < e; ++idx) {
      // Placeholders and Constants have no input, so they can only be
      // written to via overwritten inputs.
      if (!N.isOverwrittenNthInput(idx)) {
        continue;
      }

      const Node *nthInputNode = N.getNthInput(idx).getNode();
      isValid &= expectCompareTrue(
          "Constants can never be used as an overwritten input",
          isa<Constant>(nthInputNode), false, nthInputNode);

      // Unlike Constants, Placeholders can be used at most once as
      // overwritten inputs. Keep a map of Placeholders to Nodes that used
      // them as overwritten inputs, which is also used later to check for
      // read-after-write dependence violations.
      const auto *ph = dyn_cast<Placeholder>(nthInputNode);
      if (!ph) {
        continue;
      }
      auto varToFirstDef = placeholderWrittenTo.find(ph);
      bool writtenOnce = expectCompareTrue(
          "Placeholder has more than one write",
          varToFirstDef == placeholderWrittenTo.end(), true, ph);
      if (!writtenOnce) {
        isValid = false;
        std::string storage;
        llvm::raw_string_ostream msg(storage);

        msg << "Placeholder " << ph->getDebugDesc() << '\n';
        msg << "has more than one write; second writer found:\n";
        msg << N.getDebugDesc() << '\n';
        msg << varToFirstDef->second->getDebugDesc() << '\n';

        report(msg.str().c_str());
      }

      placeholderWrittenTo[ph] = &N;
    }
  }

  // Now check that the placeholders that are written to are either:
  // - Written by a save node, or
  // - Only used by the node that writes them.
  // If this check fails, it means we have implicit memory
  // dependencies that may not be honored by the scheduler.
  // Either the input IR is incorrect or the scheduler needs
  // fixing.
  for (const std::pair<const Placeholder *, const Node *> &varToWrite :
       placeholderWrittenTo) {
    if (isa<SaveNode>(varToWrite.second)) {
      continue;
    }
    for (const NodeUse &use : varToWrite.first->getUsers()) {
      const Node *user = use.getUser();
      // Ignore users outside this function.
      if (user->getParent() != this) {
        continue;
      }
      isValid &= expectCompareTrue(
          "Implicit read after write memory dependency may not be honored",
          user, varToWrite.second, user);
    }
  }
  return isValid;
}
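
// Verification is typically invoked after building a graph or after running a
// transformation, e.g. (a sketch; `F` and `backend` are hypothetical):
//   CHECK(F->verify(&backend)) << "Graph verification failed";
// Passing a null backend verifies against Glow's canonical tensor layouts
// instead of backend-specific ones, as the branch at the top of verify()
// shows.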

SaveNode *glow::getOutputSave(Function *F, Placeholder *PH) {
  // If PH's parent is set, check that it matches the provided function F.
  auto *PHP = PH->getParent();
  if (PHP != nullptr && F != PHP) {
    return nullptr;
  }
  for (auto &use : PH->getUsers()) {
    if (auto *save = llvm::dyn_cast<SaveNode>(use.getUser())) {
      if (save->getParent() == F && save->getPlaceholder() == PH) {
        return save;
      }
    }
  }
  return nullptr;
}
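
// Usage sketch: recovering the SaveNode that produces a given output
// placeholder, e.g. to inspect the value being saved (names hypothetical):
//   if (SaveNode *SN = getOutputSave(F, outputPH)) {
//     NodeValue saved = SN->getInput();
//   }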

/// Clone \p node and all nodes it transitively depends on into \p newF,
/// recording the mapping from original to cloned nodes in \p currToNew.
/// Storage nodes (constants and placeholders) are shared, not cloned.
/// \returns the clone of \p node.
Node *glow::recursiveClone(Function *newF, Node *node, NodeMap &currToNew) {
  Node *copy = node->clone();
  currToNew[node] = copy;
  newF->addNode(copy);
  for (unsigned inp = 0, e = copy->getNumInputs(); inp < e; inp++) {
    auto input = copy->getNthInput(inp);
    auto it = currToNew.find(input.getNode());
    Node *newInput;
    if (it != currToNew.end()) {
      newInput = it->second;
    } else if (llvm::isa<Storage>(input.getNode())) {
      // Storage lives in the module, so it is shared between functions.
      continue;
    } else {
      newInput = recursiveClone(newF, input.getNode(), currToNew);
    }
    copy->setNthInput(inp, NodeValue(newInput, input.getResNo()));
  }
  return copy;
}
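
// Usage sketch: extracting the subgraph feeding one node into a fresh
// function (names hypothetical):
//   NodeMap map;
//   Function *subF = F->getParent()->createFunction("sub");
//   Node *root = recursiveClone(subF, someNode, map);
// The recursion memoizes through `map`, so shared subexpressions are cloned
// exactly once.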

namespace glow {
/// If \p PH is an output placeholder, \returns true.
/// This is determined by checking if the PH has a user which uses the PH as an
/// overwritten input.
bool isOutput(const Placeholder *PH, const Function &F) {
  for (const auto &use : PH->getUsers()) {
    // Look through the inputs of the PH's users. If an input is overwritten,
    // check whether it is the PH; if so, return true.
    auto *user = use.getUser();
    // Consider only users inside the same function.
    if (user->getParent() != &F) {
      continue;
    }
    for (unsigned i = 0, numInputs = user->getNumInputs(); i < numInputs; i++) {
      // If the input is not overwritten we can continue.
      if (!user->isOverwrittenNthInput(i)) {
        continue;
      }
      auto input = user->getNthInput(i);
      if (input.getNode() == PH) {
        return true;
      }
    }
  }
  return false;
}

/// If \p PH is an input placeholder, \returns true.
bool isInput(const Placeholder *PH, const Function &F) {
  // The PH is an input if some user in F reads it: either a non-save node, or
  // a SaveNode that takes the PH as its input (rather than writing to it).
  for (const auto &use : PH->getUsers()) {
    // Consider only users inside the same function.
    if (use.getUser()->getParent() != &F) {
      continue;
    }
    // Check if PH is an input to a SaveNode.
    if (auto *save = dyn_cast<SaveNode>(use.getUser())) {
      auto input = save->getInput();
      // If the PH is not an input to the SaveNode, keep looking.
      if (input.getNode() != PH) {
        continue;
      }
    }
    return true;
  }
  return false;
}
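
// Note that isInput() and isOutput() are not mutually exclusive: a
// placeholder written by one node and read by another inside the same
// function satisfies both. A sketch of a typical query (names hypothetical):
//   bool readOnly = isInput(PH, *F) && !isOutput(PH, *F);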

llvm::raw_ostream &operator<<(llvm::raw_ostream &os, const Module &mod) {
  mod.dump(os);
  return os;
}

llvm::raw_ostream &operator<<(llvm::raw_ostream &os, const Module *mod) {
  assert(mod != nullptr && "Null Pointer.");
  mod->dump(os);
  return os;
}

llvm::raw_ostream &operator<<(llvm::raw_ostream &os, const Function &F) {
  F.dump(os);
  return os;
}

llvm::raw_ostream &operator<<(llvm::raw_ostream &os, const Function *F) {
  assert(F != nullptr && "Null Pointer.");
  F->dump(os);
  return os;
}

/// \returns true if the convolution \p node is equivalent to a FullyConnected
/// operation: a 2D NHWC convolution with 1x1 kernels, unit strides, no
/// padding, a single group, no dilation and no fused activation. If
/// \p enforceInput1x1 is true, additionally require the spatial input
/// dimensions to be 1x1.
bool isConvolutionSameAsFullyConnected(const ConvolutionNode *node,
                                       bool enforceInput1x1) {
  bool isConv2D = (node->getInput().getType()->dims().size() == 4);
  if (!(isConv2D && node->getLayout() == ConvolutionLayout::NHWC &&
        !node->hasFusedActivation())) {
    return false;
  }
  auto filterDims = ShapeNHWC(node->getFilter().getType()->dims());
  ShapeHW kernels = ShapeHW(node->getKernels());
  ShapeHW strides = ShapeHW(node->getStrides());
  PaddingTLBR pads = PaddingTLBR(node->getPads());
  auto group = node->getGroup();
  auto dilation = node->getDilation();

  bool isSame = (filterDims.h == 1) && (filterDims.w == 1);
  isSame &= (kernels.height == 1) && (kernels.width == 1);
  isSame &= (strides.height == 1) && (strides.width == 1);
  isSame &= (pads.top == 0) && (pads.left == 0) && (pads.bottom == 0) &&
            (pads.right == 0);
  isSame &= (group == 1);
  isSame &= (dilation == 1);
  if (enforceInput1x1) {
    auto inputDims = ShapeNHWC(node->getInput().getType()->dims());
    isSame &= (inputDims.h == 1) && (inputDims.w == 1);
  }
  return isSame;
}
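
// Why the equivalence holds: with a 1x1 kernel, unit strides, zero padding,
// a single group and no dilation, each output pixel (n, h, w, :) depends only
// on the input pixel (n, h, w, :). Reshaping the NHWC input [N, H, W, C] to
// [N*H*W, C] and the filter from [outC, 1, 1, C] to [C, outC] (transposed)
// reduces the convolution to one matrix multiply plus bias, which is exactly
// what a FullyConnected node computes.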

/// \returns true if the Gemm \p node is equivalent to a FullyConnected
/// operation: alpha == beta == 1 and the C operand is a 1D bias.
bool isGemmSameAsFullyConnected(const GemmNode *node) {
  NodeValue inpC = node->getC();
  return (node->getAlpha() == 1.0) && (node->getBeta() == 1.0) &&
         (inpC.getNode()) && (inpC.dims().size() == 1);
}
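
// Rationale: Gemm computes alpha * A * B + beta * C. With alpha == beta == 1
// and a rank-1 C that broadcasts across the rows of the product, this is
// exactly FullyConnected's "input * weights + bias", so such a Gemm can be
// treated as (or lowered to) a FullyConnected node.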

} // namespace glow