/**
 * Copyright (c) Glow Contributors. See CONTRIBUTORS file.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef GLOW_GRAPH_GRAPH_H
#define GLOW_GRAPH_GRAPH_H

#include "glow/Base/Type.h"
#include "glow/Graph/Log.h"
#include "glow/Graph/Nodes.h"
#include "glow/Quantization/Base/Base.h"

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/ADT/ilist.h"
#include "llvm/ADT/ilist_node.h"

#include <list>
#include <vector>

namespace glow {
class PlaceholderBindings;

/// List of Types.
using TypesList = std::list<Type>;
/// Intrusive list of Nodes.
using NodesList = llvm::iplist<glow::Node>;
/// List of pointers to Nodes. The nodes are not owned by the list.
using NodesPtrList = std::list<glow::Node *>;
/// List of Functions.
using FunctionList = std::list<Function *>;
using ConstList = std::list<Constant *>;
using PlaceholderList = std::list<Placeholder *>;
using UnsignedArrayRef = llvm::ArrayRef<dim_t>;
/// Map from original Nodes to cloned Nodes.
using NodeMap = llvm::DenseMap<Node *, Node *>;
/// State of a function. This can be used to control optimizations which depend
/// on the state of the Function. This is a temporary workaround until GH Issue
/// #3213 is complete.
enum class FunctionState {
  /// Indicates that the function has been created but not completely loaded.
  FuncCreated,
  /// Indicates that the function has been completely loaded.
  FuncLoaded,
};

/// Helper names for common tensor layouts.
#define ANY_LAYOUT "*"

class Module final {
  /// Stores the functions in the module.
  FunctionList functions_;
  /// A uniqued list of types. Types in this list can be equated by comparing
  /// their addresses.
  TypesList types_{};
  /// Stores a list of unique Storage names that were used by the module at
  /// some point.
  llvm::StringSet<> usedStorageNames_{};
  /// Stores a list of node names that were used by Functions of this module at
  /// some point.
  llvm::StringSet<> usedNodeNames_{};
  /// Stores a list of node names that were present in the original model and
  /// are good to be retained.
  llvm::StringSet<> originalNames_{};
  /// A list of constants that the Module owns.
  ConstList constants_;
  /// A list of placeholder nodes that the Module owns.
  PlaceholderList placeholders_;
  /// Deterministic PRNG used to initialize weights in this module.
  PseudoRNG PRNG_;

  /// Module log context that stores all logs related to this module.
  LogContext moduleLogCtx_{nullptr};

  /// Inserts the constant \p V to the list of constants.
  Constant *addConstant(Constant *V);

  friend class Function;

public:
  Module() = default;

  ~Module();

  /// \returns the prefix part of the provided \p name. E.g. for an input
  /// of "relu__2" returns "relu".
  static std::string getPrefix(llvm::StringRef name);

  /// \returns unique legal name that's based on the string \p name. Legal
  /// names are legal C identifiers in the form: "[a-zA-Z_][a-zA-Z0-9_]*".
  /// The name may not be in \p stringTable or \p updateTable and will be
  /// inserted into \p updateTable.
  static llvm::StringRef uniqueName(llvm::StringRef name,
                                    const llvm::StringSet<> &stringTable,
                                    llvm::StringSet<> &updateTable,
                                    const llvm::StringSet<> &originalNames);

  /// Registers a \p name as used by some Node in this module.
  void registerNodeName(llvm::StringRef name) {
    // Don't care if it's already in the set.
    usedNodeNames_.insert(name);
  }

  /// Registers a \p name from the original model, good to be retained.
  void registerOriginalName(llvm::StringRef name) {
    // Don't care if it's already in the set.
    if (name.size()) {
      originalNames_.insert(name);
    }
  }

  /// \returns the pointer to the list of original node names, good to be
  /// retained.
  const llvm::StringSet<> *getOriginalNames() const { return &originalNames_; }

  /// Registers a name as used by a Storage node (Constant or Placeholder) in
  /// this module.
  void registerStorageName(llvm::StringRef name) {
    usedStorageNames_.insert(name);
  }

  /// \returns whether there's a Storage node already registered with \p name.
  bool hasStorageName(llvm::StringRef name) {
    return usedStorageNames_.count(name);
  }

  /// Return a pointer to a uniqued type \p T.
  TypeRef uniqueType(const Type &T);

  /// Return a pointer to a uniqued type \p T.
  TypeRef uniqueType(ElemKind elemTy, llvm::ArrayRef<dim_t> dims);

  /// Return a pointer to a uniqued type \p T.
  TypeRef uniqueType(ElemKind elemTy, llvm::ArrayRef<dim_t> dims, float scale,
                     int32_t offset);

  /// Return a pointer to a uniqued type \p T.
  /// The new type is identical to \p T, with a new shape \p dims.
  TypeRef uniqueTypeWithNewShape(TypeRef T, llvm::ArrayRef<dim_t> dims);

  /// The new type is identical to \p T, with a new shape \p dims and new \p
  /// alignments.
  TypeRef uniqueTypeWithNewShape(TypeRef T, llvm::ArrayRef<dim_t> dims,
                                 llvm::ArrayRef<dim_t> alignments);

  /// Return a pointer to a uniqued type \p T.
  /// The new type is identical to \p T, with a new shape and strides taken from
  /// the type \p shapeType.
  TypeRef uniqueTypeWithNewShape(TypeRef T, TypeRef shapeType);

  /// Return the void type.
  TypeRef getVoidTy();

  /// \returns True if a function by the name \p name exists in the module.
  bool hasFunction(llvm::StringRef name);
  /// \returns the function with the name \p name, or nullptr if the function
  /// does not exist.
  Function *getFunction(llvm::StringRef name);
  /// \returns a new function with the name \p name.
  Function *createFunction(llvm::StringRef name);
  /// \returns the list of Functions that the Module owns.
  FunctionList &getFunctions() { return functions_; }

  const FunctionList &getFunctions() const { return functions_; }

  /// Clears out all Functions from \ref functions_.
  void clearFunctions();

  /// \returns the list of types that the Module owns.
  const TypesList &getTypes() const { return types_; }

  /// Erase the constant \p N from the Module.
  void eraseConstant(Constant *N);

  /// Erase the constant pointed to by the iterator \p I from the Module.
  void eraseConstant(ConstList::iterator I);

  /// Erase the placeholder pointed to by the iterator \p I from the Module.
  /// Note: we only provide an iterator version of this, as erasing Placeholders
  /// is often unsafe.
  void erasePlaceholder(PlaceholderList::iterator I);

  /// \returns a pointer to the first Constant with the name \p name or nullptr
  /// if no node has this name.
  Constant *getConstantByName(llvm::StringRef name) const;

  /// \returns the list of constants that the Module owns.
  ConstList &getConstants() { return constants_; }

  const ConstList &getConstants() const { return constants_; }

  /// \returns the list of placeholders that the Module owns.
  PlaceholderList &getPlaceholders() { return placeholders_; }

  const PlaceholderList &getPlaceholders() const { return placeholders_; }

  /// \returns a pointer to the placeholder with the name \p name or
  /// nullptr if no placeholder has this name.
  Placeholder *getPlaceholderByNameSlow(llvm::StringRef name) const;

  /// @name High-level Storage builders.
  ///@{

  Placeholder *createPlaceholder(ElemKind T, llvm::ArrayRef<dim_t> dims,
                                 llvm::StringRef name, bool isTrainable,
                                 const std::string &layout = ANY_LAYOUT);

  Placeholder *createPlaceholder(TypeRef T, llvm::StringRef name,
                                 bool isTrainable,
                                 const std::string &layout = ANY_LAYOUT);

  Placeholder *createPlaceholder(ElemKind T, llvm::ArrayRef<dim_t> dims,
                                 float scale, int32_t offset,
                                 llvm::StringRef name, bool isTrainable,
                                 const std::string &layout = ANY_LAYOUT);

  Constant *createConstant(TypeRef T, llvm::StringRef name,
                           const std::string &layout = ANY_LAYOUT);

  Constant *createConstant(ElemKind T, llvm::ArrayRef<dim_t> dims,
                           llvm::StringRef name,
                           const std::string &layout = ANY_LAYOUT);

  Constant *createConstant(ElemKind T, llvm::ArrayRef<dim_t> dims, float scale,
                           int32_t offset, llvm::StringRef name,
                           const std::string &layout = ANY_LAYOUT);

  Constant *createConstant(llvm::StringRef name, const Tensor &tensor,
                           const std::string &layout = ANY_LAYOUT);

  Constant *createConstant(llvm::StringRef name, Tensor &&tensor,
                           const std::string &layout = ANY_LAYOUT);

  ///@}

  /// Verify the correctness of the Module.
  /// \returns true when the module is valid, false otherwise.
  bool verify() const;

  /// Get the pseudo-random number generator used by this module.
  PseudoRNG &getPRNG() { return PRNG_; }

  /// Dump a textual representation of the Module into the default output
  /// stream.
  void dump() const;

  /// Dump a textual representation of the Module to std::string.
  std::string toString() const;

  /// Dump a textual representation of the Module into the provided output
  /// stream.
  void dump(llvm::raw_ostream &os) const;

  /// Dump a dotty graph that depicts the Module.
  void dumpDAG();

  /// Dump a dotty graph that depicts the Module.
  void dumpDAG(llvm::StringRef dotFilename);

  /// Dump a dotty graph that depicts the Module.
  void dumpDAG(const char *dotFilename);

  /// Erase all of the functions from the module.
  void eraseFunctions();

  /// Erase all the functions, Placeholders, Constants, etc.
  void clear();

  /// Clone a module.
  /// \returns a new module that is a copy of the current module.
  Module *clone() const;

  /// Clone the current module into a user-provided module \p M.
  /// \returns the user-provided module \p M that now contains a clone of the
  /// current module.
  Module *clone(Module *M) const;

  /// Strips payloads from constants. This is useful when the Module will be
  /// kept around for metadata but we want to reduce memory use. Unlike
  /// \ref clear, this leaves PHs and Constants in the module.
  void strip();

  /// Erase a function \p F from the module.
  void eraseFunction(Function *F);

  /// \returns the size in bytes of data used by constants.
  uint64_t getConstantsSize();

  /// \returns the module log context.
  LogContext *getModuleLogContext() { return &moduleLogCtx_; }

  // Don't copy or move this class around. The destructor will wipe out the
  // functions, leaving the original Module with only dangling pointers.
  Module(const Module &) = delete;
  Module(Module &&) = delete;
  Module &operator=(const Module &) = delete;
  Module &operator=(Module &&) = delete;
};

// Forward declarations for verify's optional parameters.
class Backend;
struct CompilationContext;

/// Represents the compute graph.
class Function final : public IRContainer {
  /// A list of nodes that the Function owns.
  NodesList nodes_;

  /// A list of metadata PHs associated with the function.
  std::vector<Placeholder *> metadataPlaceholders_;

  /// Stores a list of unique node names that were used by this function at
  /// some point.
  llvm::StringSet<> uniqueNodeNames_{};

  /// A reference to the owner of the function.
  Module *parent_;

  /// The log context associated with this function.
  std::shared_ptr<LogContext> logCtx_;

  /// The state of this function.
  FunctionState state_;

public:
  Function(Module *parent, llvm::StringRef Name = {})
      : IRContainer(Name), parent_(parent), state_(FunctionState::FuncCreated) {
    logCtx_ = std::make_shared<LogContext>(parent);
    logCtx_->pushEvent(parent->getModuleLogContext()->getClonedScope());
  }

  ~Function();

  /// Clear out \ref nodes_ and \ref uniqueNodeNames_.
  void clear();

  /// Sets the state of the function.
  void setState(FunctionState state) { state_ = state; }

  /// Gets the state of the function.
  FunctionState getState() { return state_; }

  std::string getFilename() { return getName().rsplit('/').second.str(); }

  /// Return the log context.
  std::shared_ptr<LogContext> getLogContext() { return logCtx_; }

  /// Add a placeholder for metadata such as profiling.
  void addMetadataPlaceholder(Placeholder *PH) {
    metadataPlaceholders_.push_back(PH);
  }

  /// Get the list of metadata placeholders.
  const std::vector<Placeholder *> &getMetadataPlaceholders() const {
    return metadataPlaceholders_;
  }

  Module *getParent() { return parent_; }

  /// Orders nodes_ by node name. This makes sure that optimizations have
  /// deterministic behavior on graphs which have the same ops but a different
  /// ordering in nodes_.
  void orderNodes() {
    nodes_.sort(
        [](const Node &a, const Node &b) { return a.getName() < b.getName(); });
  }

  /// Search the Module containing the function to gather and return a list of
  /// placeholders that are used by the Function.
  PlaceholderList findPlaceholders();
  PlaceholderList findPlaceholders() const;

  /// Search the Module containing the function to gather and return a list of
  /// constants that are used by the Function.
  ConstList findConstants();
  ConstList findConstants() const;

  const Module *getParent() const { return parent_; }

  /// Inserts the node \p N to the list of nodes, and returns the inserted node.
  template <class NodeTy> NodeTy *addNode(NodeTy *N) {
    N->setName(Module::uniqueName(N->getName(), parent_->usedStorageNames_,
                                  uniqueNodeNames_, parent_->originalNames_));
    parent_->registerNodeName(N->getName());
    nodes_.push_back(N);

    // Log the node creation.
    logCtx_->logNodeCreation(*N);

    return N;
  }

  /// Take ownership of \p N by removing it from its original Function, add it
  /// to the current Function, and also unique its name.
  void takeOwnershipOfNode(Node *N) {
    N->getParent()->getNodes().remove(N);
    N->setName(Module::uniqueName(N->getName(), parent_->usedStorageNames_,
                                  uniqueNodeNames_, parent_->originalNames_));
    parent_->registerNodeName(N->getName());
    nodes_.push_back(N);
  }

  /// Get the pseudo-random number generator used by this module.
  PseudoRNG &getPRNG() { return getParent()->getPRNG(); }

  /// @name High-level, operation-level IRBuilder.
  ///@{

  /// Creates a PadNode with the given \p name and output type \p outTy which
  /// pads the given \p input with the explicit pads \p pads according to the
  /// padding mode \p mode and with the given value \p value. The padding mode
  /// \p mode is one of enumeration values from \ref PaddingMode. For an input
  /// with N dimensions (rank N) the \p pads must be a vector with 2*N values
  /// with the following format:
  /// pads = [pad_before(D1), pad_before(D2), ..., pad_before(DN),
  ///         pad_after (D1), pad_after (D2), ..., pad_after (DN)].
  /// The mode PaddingMode::CONSTANT pads the input using the constant value
  /// \p value and currently is the only mode supported.
  PadNode *createPad(llvm::StringRef name, NodeValue input, TypeRef outTy,
                     unsigned_t mode, llvm::ArrayRef<int> pads, float value);
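
  // Illustrative sketch (not part of the original docs; assumes a
  // Function *F, a NodeValue in of shape [2, 3], and a uniqued outTy of
  // shape [3, 4]): pad one row before dimension D1 and one column after
  // dimension D2 with zeros, so pads = {1, 0, 0, 1}:
  //   PadNode *PN = F->createPad("pad", in, outTy,
  //                              /* mode, assumed PaddingMode::CONSTANT */ 0,
  //                              /* pads */ {1, 0, 0, 1}, /* value */ 0.f);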

  /// Creates a ConvolutionNode with the given \p name which convolves the 4D
  /// \p input with \p filter and \p bias. \p kernels defines the size of the
  /// height and width dimensions of the filters. \p strides defines the number
  /// of steps to take in the input for each output cell. \p pads defines how
  /// many zero padding cells should be added to the input during convolution.
  /// \p group defines the number of groups the input and output channels should
  /// be divided into and convolved separately. \p dilation defines the factor
  /// by which the gap between two neighboring kernel elements is expanded along
  /// each axis. \p layout defines the Tensor layout and must be either NHWC or
  /// NCHW.
  ConvolutionNode *
  createConv(llvm::StringRef name, NodeValue input, NodeValue filter,
             NodeValue bias, TypeRef outTy, llvm::ArrayRef<unsigned_t> kernels,
             llvm::ArrayRef<unsigned_t> strides,
             llvm::ArrayRef<unsigned_t> pads, unsigned_t group,
             unsigned_t dilation = 1,
             ConvolutionLayout layout = ConvolutionLayout::NHWC);
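
  // Illustrative sketch (assumes a Function *F and pre-created operands; not
  // from the original docs): a 3x3, stride-1, same-padded NHWC convolution
  // taking a [1, 5, 5, 3] input to a [1, 5, 5, 8] output, with a
  // [8, 3, 3, 3] filter and an [8] bias:
  //   ConvolutionNode *CN =
  //       F->createConv("conv", in, filter, bias, outTy,
  //                     /* kernels */ {3, 3}, /* strides */ {1, 1},
  //                     /* pads */ {1, 1, 1, 1}, /* group */ 1);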

  /// Creates a ConvolutionNode with the given \p name which convolves the 4D
  /// \p input with \p filter and \p bias. \p kernel defines the size of the
  /// height and width dimensions of the filters. \p stride defines the number
  /// of steps to take in the input for each output cell. \p pad defines how
  /// many zero padding cells should be added to the input during convolution.
  /// \p group defines the number of groups the input and output channels should
  /// be divided into and convolved separately. \p dilation defines the factor
  /// by which the gap between two neighboring kernel elements is expanded along
  /// each axis. \p layout defines the Tensor layout and must be either NHWC or
  /// NCHW.
  ConvolutionNode *
  createConv(llvm::StringRef name, NodeValue input, NodeValue filter,
             NodeValue bias, TypeRef outTy, unsigned_t kernel,
             unsigned_t stride, unsigned_t pad, unsigned_t group,
             unsigned_t dilation = 1,
             ConvolutionLayout layout = ConvolutionLayout::NHWC);

  /// Creates a Convolution3DNode with the given \p name which convolves the 5D
  /// \p input with \p filter and \p bias. \p kernels defines the size of the
  /// height, width, and depth dimensions of the filters. \p strides defines
  /// the number of steps to take in the input for each output cell. \p pads
  /// defines how many zero padding cells should be added to the input during
  /// convolution. \p group defines the number of groups the input and output
  /// channels should be divided into and convolved separately. \p outTy defines
  /// the type of the output of the 3d convolution.
  Convolution3DNode *createConv3D(llvm::StringRef name, NodeValue input,
                                  NodeValue filter, NodeValue bias,
                                  TypeRef outTy,
                                  llvm::ArrayRef<unsigned_t> kernels,
                                  llvm::ArrayRef<unsigned_t> strides,
                                  llvm::ArrayRef<unsigned_t> pads,
                                  unsigned_t group);

  /// Creates a Convolution3DNode with the given \p name which convolves the 5D
  /// \p input with \p filter and \p bias. \p kernel defines the size of the
  /// height, width, and depth dimensions of the filters. \p stride defines
  /// the number of steps to take in the input for each output cell. \p pad
  /// defines how many zero padding cells should be added to the input during
  /// convolution. \p group defines the number of groups the input and output
  /// channels should be divided into and convolved separately. \p outTy defines
  /// the type of the output of the 3d convolution.
  Convolution3DNode *createConv3D(llvm::StringRef name, NodeValue input,
                                  NodeValue filter, NodeValue bias,
                                  TypeRef outTy, unsigned_t kernel,
                                  unsigned_t stride, unsigned_t pad,
                                  unsigned_t group);

  /// Creates a ChannelwiseQuantizedConvolutionNode with the given \p name which
  /// convolves the 4D/5D \p input with \p filter and \p bias. \p filterScales
  /// and \p filterOffsets provide individual quantization parameters for each
  /// filter group in \p filter while \p biasScales and \p biasOffsets provide
  /// individual quantization parameters for each bias element corresponding to
  /// each output channel. \p kernels defines the size of the height and width
  /// dimensions of the filters. \p strides defines the number of steps to take
  /// in the input for each output cell. \p pads defines how many zero padding
  /// cells should be added to the input during convolution. \p group defines
  /// the number of groups the input and output channels should be divided into
  /// and convolved separately. \p dilation defines the filter dilation.
  /// This function is flexible and has the following features:
  /// - it can be provided with a floating-point \p filter and the function will
  ///   automatically quantize the filter channelwise using the given schema
  ///   \p schema and type \p filterElemQTy.
  /// - it can be provided with a floating-point \p bias and the function will
  ///   automatically quantize the bias channelwise using the given schema
  ///   \p schema and type \p biasElemQTy.
  /// - if \p filter is floating-point and \p filterScales or \p filterOffsets
  ///   are not provided then this function will derive them automatically.
  /// - if \p filter is quantized then \p filterScales or \p filterOffsets are
  ///   mandatory.
  /// - if \p bias is floating-point and \p biasScales or \p biasOffsets are not
  ///   provided then this function will derive them automatically.
  /// - if \p bias is quantized and \p biasScales or \p biasOffsets are not
  ///   provided then this function will assume the implicit parameters
  ///   biasScales[i] = inputScale * filterScales[i] and biasOffsets[i] = 0.
  ///   Note that this case can safely handle only the INT32 bias data type
  ///   because for the INT8 type the bias will almost certainly be saturated.
  /// This function will only quantize the filter if \p quantizeFilter is set
  /// to true and will only quantize the bias if \p quantizeBias is set to true
  /// such that a floating-point filter/bias can be attached to the node as-is
  /// without any modifications in order for the backends to perform their own
  /// custom quantization later if desired.
  /// This function requires the \p filter and \p bias operands to be constants.
  ChannelwiseQuantizedConvolutionNode *createChannelwiseQuantizedConv(
      llvm::StringRef name, NodeValue input, NodeValue filter, NodeValue bias,
      NodeValue filterScales, NodeValue filterOffsets, NodeValue biasScales,
      NodeValue biasOffsets, TypeRef outTy, llvm::ArrayRef<unsigned_t> kernels,
      llvm::ArrayRef<unsigned_t> strides, llvm::ArrayRef<unsigned_t> pads,
      unsigned_t group, unsigned_t dilation = 1, bool quantizeFilter = true,
      bool quantizeBias = true,
      quantization::Schema schema = quantization::Schema::Asymmetric,
      ElemKind filterElemQTy = ElemKind::Int8QTy,
      ElemKind biasElemQTy = ElemKind::Int32QTy);
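
  // Worked sketch of the implicit bias quantization rule above (illustrative
  // numbers, not from the original docs): with a quantized INT32 bias and no
  // biasScales/biasOffsets provided, an input scale of 0.5 and
  // filterScales = {0.1, 0.2} imply
  //   biasScales  = {0.5 * 0.1, 0.5 * 0.2} = {0.05, 0.1};
  //   biasOffsets = {0, 0};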

  /// Creates a ConvTransposeNode with the given \p name which does transposed
  /// convolution of the 4D \p input with \p filter and \p bias. \p kernels
  /// define the size of the height and width dimensions of the filters.
  /// \p strides define the number of steps to take in the input for each output
  /// cell. \p pads define how many zero padding cells should be added to the
  /// input during convolution. \p group defines the number of groups the input
  /// and output channels should be divided into and convolved separately.
  ConvTransposeNode *createConvTranspose(
      llvm::StringRef name, NodeValue input, NodeValue filter, NodeValue bias,
      TypeRef outTy, llvm::ArrayRef<unsigned_t> kernels,
      llvm::ArrayRef<unsigned_t> strides, llvm::ArrayRef<unsigned_t> pads,
      unsigned_t group, unsigned_t dilation = 1);

  /// Creates a ConvTransposeNode with the given \p name which does transposed
  /// convolution of the 4D \p input with \p filter and \p bias. \p kernel
  /// defines the size of the height and width dimensions of the filters.
  /// \p stride defines the number of steps to take in the input for each output
  /// cell. \p pad defines how many zero padding cells should be added to the
  /// input during convolution. \p group defines the number of groups the input
  /// and output channels should be divided into and convolved separately.
  ConvTransposeNode *createConvTranspose(llvm::StringRef name, NodeValue input,
                                         NodeValue filter, NodeValue bias,
                                         TypeRef outTy, unsigned_t kernel,
                                         unsigned_t stride, unsigned_t pad,
                                         unsigned_t group,
                                         unsigned_t dilation = 1);

  /// Creates and \returns a ConvertTo Node with name \p name of \p input to
  /// output type \p outTy.
  ConvertToNode *createConvertTo(llvm::StringRef name, NodeValue input,
                                 TypeRef outTy);

  /// Creates and \returns a ConvertTo Node with name \p name of \p input to
  /// output ElemKind \p k.
  ConvertToNode *createConvertTo(llvm::StringRef name, NodeValue input,
                                 ElemKind k);

  MaxPoolNode *createMaxPool(llvm::StringRef name, NodeValue input,
                             llvm::ArrayRef<unsigned_t> kernels,
                             llvm::ArrayRef<unsigned_t> strides,
                             llvm::ArrayRef<unsigned_t> pads,
                             ElemKind elemTyAMT = ElemKind::Int64ITy,
                             ConvolutionLayout layout = NHWC);

  MaxPoolNode *createMaxPool(llvm::StringRef name, NodeValue input,
                             unsigned_t kernel, unsigned_t stride,
                             unsigned_t pad,
                             ElemKind elemTyAMT = ElemKind::Int64ITy,
                             ConvolutionLayout layout = NHWC);

  AvgPoolNode *createAvgPool(llvm::StringRef name, NodeValue input,
                             llvm::ArrayRef<unsigned_t> kernels,
                             llvm::ArrayRef<unsigned_t> strides,
                             llvm::ArrayRef<unsigned_t> pads,
                             ConvolutionLayout layout = NHWC);

  AvgPoolNode *createAvgPool(llvm::StringRef name, NodeValue input,
                             TypeRef outTy, llvm::ArrayRef<unsigned_t> kernels,
                             llvm::ArrayRef<unsigned_t> strides,
                             llvm::ArrayRef<unsigned_t> pads,
                             ConvolutionLayout layout = NHWC);

  AvgPoolNode *createAvgPool(llvm::StringRef name, NodeValue input,
                             unsigned_t kernel, unsigned_t stride,
                             unsigned_t pad, ConvolutionLayout layout = NHWC);

  /// Creates and \returns an AdaptiveAvgPool node with \p name, \p input, and
  /// \p outTy. The AdaptiveAvgPoolNode will perform average pooling over the
  /// input so that the result is of the shape specified by \p outTy.
  AdaptiveAvgPoolNode *createAdaptiveAvgPool(llvm::StringRef name,
                                             NodeValue input, TypeRef outTy);

  /// Creates and \returns a General Matrix Multiplication (Gemm) node with
  /// given \p name which computes Y = alpha * A * B + beta * C. The operands
  /// \p A and \p B are 2D matrices, the \p C operand is an optional 1D or 2D
  /// matrix (broadcastable to the size of Y) and \p alpha and \p beta are float
  /// scalars. The \p C operand is optional; if nullptr is given, it is not
  /// used. If \p transposeA or \p transposeB is true then \p A or \p B is
  /// additionally transposed prior to matrix multiplication.
  /// If the output shape of Y is [M,N] then:
  /// - The shape of \p A must be [M,K] or [K,M] (if transposed).
  /// - The shape of \p B must be [K,N] or [N,K] (if transposed).
  /// - The shape of \p C must be [N] (if 1D) or [M,N] (if 2D).
  GemmNode *createGemm(llvm::StringRef name, NodeValue A, NodeValue B,
                       NodeValue C = nullptr, float alpha = 1.0,
                       float beta = 1.0, bool transposeA = false,
                       bool transposeB = false);

  GemmNode *createGemm(llvm::StringRef name, TypeRef outTy, NodeValue A,
                       NodeValue B, NodeValue C = nullptr, float alpha = 1.0,
                       float beta = 1.0, bool transposeA = false,
                       bool transposeB = false);
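
  // Shape sketch (illustrative, not from the original docs; assumes a
  // Function *F): with A = [4, 2], B = [2, 3], C = [3], and both transpose
  // flags false, Y = alpha * A * B + beta * C has shape [M, N] = [4, 3] and
  // the 1D C is broadcast across the M rows:
  //   GemmNode *GN = F->createGemm("gemm", A, B, C, /* alpha */ 1.0,
  //                                /* beta */ 1.0);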

  /// Creates and \returns a FullyConnectedNode with \p name, \p input, weights
  /// \p W, bias \p B. If \p input is not 2-dimensional then it is flattened
  /// along \p axis. Note: the output type and outputDepth are inferred based
  /// on the input types.
  FullyConnectedNode *createFullyConnected(llvm::StringRef name,
                                           NodeValue input, Storage *W,
                                           Storage *B, unsigned_t axis = 1);

  /// Creates and \returns a FullyConnectedNode with \p name, \p input, weights
  /// \p W, bias \p B. If \p input is not 2-dimensional then it is flattened
  /// along \p axis. Note: the output type and outputDepth are inferred based
  /// on the input types.
  FullyConnectedNode *createFullyConnected(llvm::StringRef name,
                                           NodeValue input, NodeValue W,
                                           NodeValue B, unsigned_t axis = 1);

  /// Creates and \returns a FullyConnectedNode with \p name, \p input, weights
  /// \p W, bias \p B, and \p outTy. If \p input is not 2-dimensional then it is
  /// flattened along \p axis. Note: outputDepth is inferred based on \p outTy.
  FullyConnectedNode *createFullyConnected(llvm::StringRef name,
                                           NodeValue input, NodeValue W,
                                           NodeValue B, TypeRef outTy,
                                           unsigned_t axis = 1);

  /// Create a row-wise quantized fully connected node. This node is only used
  /// in quantization. Args \p input and \p B are quantized in the regular way,
  /// \p W is the constant weights and is row-wise quantized using the given
  /// \p scales and \p offsets. The output is quantized in the regular way, and
  /// its type \p outTy is a quantized type.
  RowwiseQuantizedFullyConnectedNode *createRowwiseQuantizedFullyConnected(
      llvm::StringRef name, NodeValue input, Constant *W, Constant *scales,
      Constant *offsets, NodeValue B, TypeRef outTy);

  /// Create a row-wise quantized fully connected node. This node is only used
  /// in quantization. Args \p input and \p B are quantized in the regular way,
  /// \p W is the constant weights and will be row-wise quantized at node
  /// creation time. The output is quantized in the regular way, and its type
  /// \p outTy is a quantized type. If \p transposeWeight is true, \p W needs
  /// to be transposed first.
  RowwiseQuantizedFullyConnectedNode *createRowwiseQuantizedFullyConnected(
      llvm::StringRef name, NodeValue input, Constant *W, NodeValue B,
      TypeRef outTy, quantization::Schema schema, bool transposeWeight = false);

  /// Implement an operation that computes the row-wise dot product of its
  /// inputs. Consequently, \p X and \p Y must be either 1D or 2D tensors. This
  /// is lowered to a Mul node, followed by a BatchedReduceAdd if \p X and
  /// \p Y are 2D. \returns either the Mul or BatchedReduceAdd node.
  Node *createDotProduct(llvm::StringRef name, NodeValue X, NodeValue Y);
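
  // Illustrative sketch (assumes a Function *F and 2D NodeValues X, Y of
  // shape [B, K]; not from the original docs): the result is a [B] tensor
  // whose element b equals dot(X[b, :], Y[b, :]):
  //   Node *rowwiseDot = F->createDotProduct("rowwise_dot", X, Y);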

  /// Create a node that computes the pairwise dot product of \p inputs, which
  /// must be a list of 2D tensors with identical shape. \returns the
  /// BatchedPairwiseDotProductNode.
  BatchedPairwiseDotProductNode *
  createBatchedPairwiseDotProduct(llvm::StringRef name,
                                  llvm::ArrayRef<NodeValue> inputs);

  /// Create a node that implements the elementwise linear operator. \p X is
  /// 2D and \p w and \p b are 1D. \p w and \p b are broadcasted to match the
  /// shape of \p X and then the output is computed by multiplying \p X and
  /// broadcasted \p w and adding broadcasted \p b. \returns the
  /// ElementwiseLinearNode. \p axis indicates the axis of the inputs (the other
  /// axis of \p X is assumed to be the batch index).
  Node *createElementwiseLinear(llvm::StringRef name, NodeValue X, NodeValue w,
                                NodeValue b, unsigned axis);

  /// Create a ReLU node with the given \p name and \p input.
  /// Result type will be implicitly set based on the \p input type.
  ReluNode *createRELU(llvm::StringRef name, NodeValue input);

  /// Create a ReLU node with the given \p name, \p input and
  /// output type \p outTy.
  ReluNode *createRELU(llvm::StringRef name, NodeValue input, TypeRef outTy);

  /// Create a series of nodes representing a GeLU with the given \p name and \p
  /// input. Result type will be implicitly set based on the \p input type.
  Node *createGELU(llvm::StringRef name, NodeValue input);

  /// Create a PReLU node with the given \p name, \p input and \p slope.
  /// Result type will be implicitly set based on the \p input type.
  PReluNode *createPRELU(llvm::StringRef name, NodeValue input,
                         NodeValue slope);

  /// Create a PReLU node with the given \p name, \p input, \p slope and
  /// output type \p outTy.
  PReluNode *createPRELU(llvm::StringRef name, NodeValue input, NodeValue slope,
                         TypeRef outTy);

  /// Create a Sigmoid node with the given \p name, \p input and
  /// output type \p outTy.
  SigmoidNode *createSigmoid(llvm::StringRef name, TypeRef outTy,
                             NodeValue input);

  /// Create a Sigmoid node with the given \p name and \p input.
  /// Result type will be implicitly set based on the \p input type.
  SigmoidNode *createSigmoid(llvm::StringRef name, NodeValue input);

  /// Create a Swish node with the given \p name and \p input.
  /// Result type will be implicitly set based on the \p input type.
  SwishNode *createSwish(llvm::StringRef name, NodeValue input);

  /// Create a Tanh node with the given \p name, \p input and
  /// output type \p outTy.
  TanhNode *createTanh(llvm::StringRef name, TypeRef outTy, NodeValue input);

  /// Create a Tanh node with the given \p name and \p input.
  /// Result type will be implicitly set based on the \p input type.
  TanhNode *createTanh(llvm::StringRef name, NodeValue input);

  /// Create an Exp node with \p name, which calculates element-wise
  /// exponential of \p input.
  ExpNode *createExp(llvm::StringRef name, NodeValue input);

  /// Create an Exp node with \p name with output type \p outTy, which
  /// calculates element-wise exponential of \p input.
  ExpNode *createExp(llvm::StringRef name, TypeRef outTy, NodeValue input);

  /// Create a Log node with \p name, which calculates element-wise natural log
  /// of \p input, with output type \p outTy.
  LogNode *createLog(llvm::StringRef name, NodeValue input,
                     TypeRef outTy = nullptr);

  /// \returns a LogitNode with \p name given \p input and \p eps.
  LogitNode *createLogit(llvm::StringRef name, NodeValue input, float eps);

  SoftMaxNode *createSoftMax(llvm::StringRef name, NodeValue input,
                             NodeValue selected, TypeRef outTy = nullptr,
                             float beta = 1.0);

  CrossEntropyLossNode *createCrossEntropyLoss(llvm::StringRef name,
                                               NodeValue input,
                                               NodeValue labels);

  RegressionNode *createRegression(llvm::StringRef name, NodeValue input,
                                   NodeValue expected);

  /// Creates a node which computes sigmoid cross entropy between two inputs.
  SigmoidCrossEntropyWithLogitsNode *
  createSigmoidCrossEntropyWithLogits(llvm::StringRef name, NodeValue logits,
                                      NodeValue targets);

  ReshapeNode *createReshape(llvm::StringRef name, NodeValue input,
                             UnsignedArrayRef shape,
                             llvm::StringRef layout = ANY_LAYOUT);

  TransposeNode *createTranspose(llvm::StringRef name, NodeValue input,
                                 llvm::ArrayRef<unsigned_t> shuffle,
                                 const std::string &layout = ANY_LAYOUT);

  /// Create a node with the name \p name which flips (reorders) the elements
  /// of the input \p input along the given axis \p axis.
  FlipNode *createFlip(llvm::StringRef name, NodeValue input, unsigned_t axis);

  /// Create a series of nodes that implement a Broadcast operation. The \p
  /// input Tensor is broadcasted based on \p newShape and along the \p axis,
  /// which defines the offset from the leading dimension under which
  /// broadcasting is performed.
  Node *createBroadcast(llvm::StringRef name, NodeValue input,
                        UnsignedArrayRef newShape, unsigned_t axis);
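
  // Illustrative sketch (assumes a Function *F and a NodeValue in of shape
  // [3, 1]; not from the original docs): with newShape = {2, 3, 4} and
  // axis = 1, the input dimensions align with output dimensions 1 and 2, so
  // the [3, 1] input is tiled to a [2, 3, 4] result:
  //   Node *B = F->createBroadcast("bcast", in, {2, 3, 4}, /* axis */ 1);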

  /// Create concat node which concatenates input tensors along \p dimension.
  ConcatNode *createConcat(llvm::StringRef name,
                           llvm::ArrayRef<NodeValue> inputs,
                           unsigned_t dimension);

  /// Create concat node with the given return type \p outTy.
  ConcatNode *createConcat(llvm::StringRef name,
                           llvm::ArrayRef<NodeValue> inputs,
                           unsigned_t dimension, TypeRef outTy);

  /// Create a TileNode with \p name, \p input, \p tiles, and \p axis.
  /// For example, an input tensor {{1,2,3,4}} of dimension 1x4 with tiles = 2
  /// and axis = 0 would result in an output tensor {{1,2,3,4}, {1,2,3,4}} of
  /// dimension 2x4.
  TileNode *createTile(llvm::StringRef name, NodeValue input, unsigned_t tiles,
                       unsigned_t axis, TypeRef outTy = nullptr);

  /// Create an insert tensor node \p name, which inserts \p small into \p big
  /// at offset \p start, \p count times along \p axis.
  InsertTensorNode *createInsertTensor(llvm::StringRef name, NodeValue big,
                                       NodeValue small,
                                       llvm::ArrayRef<dim_t> start,
                                       unsigned_t count = 1,
                                       unsigned_t axis = 0);

  /// Create a slice node \p name with the given starting points for each
  /// dimension \p begin and end points \p end (exclusive).
  SliceNode *createSlice(llvm::StringRef name, NodeValue input,
                         UnsignedArrayRef begin, UnsignedArrayRef end);
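
  // Illustrative sketch (assumes a Function *F and a NodeValue in of shape
  // [3, 4]; not from the original docs): take rows 1..2 and all four columns,
  // producing a [2, 4] result (end coordinates are exclusive):
  //   SliceNode *SN = F->createSlice("slice", in, /* begin */ {1, 0},
  //                                  /* end */ {3, 4});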

  /// Create a slice node with the given starting point for each dimension.
  /// End points will be calculated based on the output type during execution.
  SliceNode *createSlice(llvm::StringRef name, NodeValue input,
                         llvm::ArrayRef<dim_t> start, TypeRef outTy);

  /// Shuffles dimension number \p kernel. Suppose the original size is D. It
  /// will be represented as a group x (D/group) matrix, transposed, and
  /// concatenated back to size D. For example, a shuffle of {1, 2, 3, 4, 5, 6}
  /// with \p group = 2 is {1, 4, 2, 5, 3, 6}.
  Node *createChannelShuffle(llvm::StringRef name, NodeValue input,
                             size_t group, size_t kernel);

  /// Computes the indices of the max elements of the input tensor along the
  /// provided \p axis. The resulting tensor has the same rank as the input if
  /// \p keepDims is true. If \p keepDims is false, the resulting tensor has the
  /// reduced dimension pruned. The type of the output tensor is \p elemTy.
  ArgMaxNode *createArgMax(llvm::StringRef name, NodeValue input,
                           unsigned_t axis, bool keepDims,
                           ElemKind elemTy = ElemKind::Int64ITy);

  /// Computes the indices of the min elements of the input tensor along the
  /// provided \p axis. The resulting tensor has the same rank as the input if
  /// \p keepDims is true. If \p keepDims is false, the resulting tensor has the
  /// reduced dimension pruned. The type of the output tensor is \p elemTy.
  ArgMinNode *createArgMin(llvm::StringRef name, NodeValue input,
                           unsigned_t axis, bool keepDims,
                           ElemKind elemTy = ElemKind::Int64ITy);

  /// Removes single-dimensional entries from the shape of a tensor. The
  /// parameter \p axes is a list of positive integers, indicating the
  /// dimensions to squeeze. Implemented as a single ReshapeNode. This is the
  /// opposite of ExpandDims.
  /// https://github.com/onnx/onnx/blob/master/docs/Operators.md#squeeze
  ReshapeNode *createSqueeze(llvm::StringRef name, NodeValue input,
                             llvm::ArrayRef<dim_t> axes);
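
  // Illustrative sketch (assumes a Function *F and a NodeValue in of shape
  // [1, 4, 1, 5]; not from the original docs): squeezing axes {0, 2} removes
  // both size-1 dimensions, reshaping the input to [4, 5]:
  //   ReshapeNode *RN = F->createSqueeze("squeeze", in, /* axes */ {0, 2});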

  /// Add single-dimensional entries to the shape of the \p input tensor at
  /// locations in \p axes. \p axes is listed as seen in the output tensor.
  /// Implemented as a single ReshapeNode. This is the opposite of Squeeze.
  ReshapeNode *createExpandDims(llvm::StringRef name, NodeValue input,
                                llvm::ArrayRef<dim_t> axes);

  /// Flattens the input tensor into a 2D matrix. If the input tensor has shape
  /// (d_0, d_1, ... d_n) then the output will have shape:
  /// (d_0 X d_1 ... d_(axis-1), d_axis X d_(axis+1) ... X d_n).
  ReshapeNode *createFlatten(llvm::StringRef name, NodeValue input,
                             unsigned_t axis);

  /// Create \p outputNum slice nodes of \p input. Slices happen along dimension
  /// number \p axis. Array \p split defines the lengths of the slices. If
  /// \p split is empty, \p input is split into equal-sized parts.
  void createSplit(llvm::StringRef name, NodeValue input, unsigned_t outputNum,
                   unsigned_t axis, llvm::ArrayRef<dim_t> split,
                   std::vector<SliceNode *> &outputs);
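
  // Illustrative sketch (assumes a Function *F and a NodeValue in of shape
  // [6, 10]; not from the original docs): split along axis 0 into three equal
  // [2, 10] slices; an empty split array requests equal-sized parts:
  //   std::vector<SliceNode *> outs;
  //   F->createSplit("split", in, /* outputNum */ 3, /* axis */ 0,
  //                  /* split */ {}, outs);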

  BatchNormalizationNode *
  createBatchNormalization(llvm::StringRef name, NodeValue input,
                           NodeValue beta, NodeValue scale, NodeValue mean,
                           NodeValue var, unsigned_t channelIdx = 0,
                           float epsilon = 1e-5, float momentum = 0.9);

  /// Creates and \returns a LayerNormalizationNode that computes the layer
  /// normalization of the innermost layers of \p input based on the shape of
  /// \p scale and \p bias. \p epsilon is a small perturbation used to avoid
  /// division by 0 during normalization.
  LayerNormalizationNode *createLayerNormalization(llvm::StringRef name,
                                                   NodeValue input,
                                                   NodeValue scale,
                                                   NodeValue bias,
                                                   float epsilon = 1e-5);

  /// Bucketizes the input tensor based on monotonically increasing \p
  /// boundaries for each value in \p input. For each value x in input, the
  /// operator \returns index i given boundaries[i-1] < x <= boundaries[i]. If
  /// the value x is beyond the bounds of boundaries, 0 or len(boundaries) is
  /// returned as appropriate.
  BucketizeNode *createBucketizeNode(llvm::StringRef name, NodeValue input,
                                     llvm::ArrayRef<float> boundaries);
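
  // Worked sketch of the indexing rule above (illustrative numbers, not from
  // the original docs): with boundaries = {1.0, 4.0}, the inputs
  // {0.5, 3.0, 7.0} map to indices {0, 1, 2}: 0.5 <= 1.0 gives bucket 0,
  // 1.0 < 3.0 <= 4.0 gives bucket 1, and 7.0 > 4.0 gives len(boundaries) = 2.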

  LocalResponseNormalizationNode *createLocalResponseNormalization(
      llvm::StringRef name, NodeValue input, unsigned_t halfWindowSize = 2,
      float alpha = 1e-4, float beta = 0.75, float k = 2.0);

  /// Create a ModuloNode which performs the modulo operation elementwise on the
  /// \p input such that each element in the output is equal to the
  /// corresponding element in the input modulo \p divisor. If \p
  /// signFollowDivisor is true then any negative elements in the output will
  /// have \p divisor added to their final values.
  ModuloNode *createModulo(llvm::StringRef name, NodeValue input,
                           int64_t divisor, bool signFollowDivisor = false);
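
  // Worked sketch of signFollowDivisor (illustrative numbers, not from the
  // original docs): with divisor = 3, the input {-4, -1, 5} yields
  // {-1, -1, 2} when signFollowDivisor is false (C-style remainder) and
  // {2, 2, 2} when it is true (negative remainders get divisor added).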

  /// Create a logical NOT node with name \p name and input \p input.
  NotNode *createNot(llvm::StringRef name, NodeValue input);

#define UNARY_ARITHMETIC_FUN_DECL(NODE_NAME_)                                  \
  NODE_NAME_##Node *create##NODE_NAME_(llvm::StringRef name, NodeValue input); \
  NODE_NAME_##Node *create##NODE_NAME_(llvm::StringRef name, TypeRef Ty,       \
                                       NodeValue input);
  UNARY_ARITHMETIC_FUN_DECL(Abs)
  UNARY_ARITHMETIC_FUN_DECL(Neg)
  UNARY_ARITHMETIC_FUN_DECL(Floor)
  UNARY_ARITHMETIC_FUN_DECL(Ceil)
  UNARY_ARITHMETIC_FUN_DECL(Round)
  UNARY_ARITHMETIC_FUN_DECL(Sqrt)
  UNARY_ARITHMETIC_FUN_DECL(Rsqrt)
  UNARY_ARITHMETIC_FUN_DECL(Reciprocal)
  UNARY_ARITHMETIC_FUN_DECL(Sin)
  UNARY_ARITHMETIC_FUN_DECL(Cos)
#undef UNARY_ARITHMETIC_FUN_DECL

#define ARITHMETIC_FUN_DECL(NODE_NAME_)                                        \
  NODE_NAME_##Node *create##NODE_NAME_(llvm::StringRef name, NodeValue LHS,    \
                                       NodeValue RHS);                         \
  NODE_NAME_##Node *create##NODE_NAME_(llvm::StringRef name, TypeRef Ty,       \
                                       NodeValue LHS, NodeValue RHS);
  ARITHMETIC_FUN_DECL(Add);
  ARITHMETIC_FUN_DECL(Mul);
  ARITHMETIC_FUN_DECL(Sub);
  ARITHMETIC_FUN_DECL(Div);
  ARITHMETIC_FUN_DECL(Max);
  ARITHMETIC_FUN_DECL(Min);
  ARITHMETIC_FUN_DECL(CmpEQ);
  ARITHMETIC_FUN_DECL(CmpNEQ);
  ARITHMETIC_FUN_DECL(CmpLT);
  ARITHMETIC_FUN_DECL(CmpLTE);
  ARITHMETIC_FUN_DECL(And);
  ARITHMETIC_FUN_DECL(Or);
  ARITHMETIC_FUN_DECL(Xor);
  ARITHMETIC_FUN_DECL(Pow);
#undef ARITHMETIC_FUN_DECL

  std::vector<NodeValue>
  broadcastInputs(int axis, const llvm::ArrayRef<NodeValue> inputs);

  template <class T, class U>
  using enable_if_same_t = std::enable_if<std::is_same<T, U>::value, U>;

#define BROADCAST_FUNC_COMMON_CODE(NUM_INPUTS)                                 \
  constexpr size_t numInputs = sizeof...(Args);                                \
  static_assert(numInputs == NUM_INPUTS,                                       \
                "Invalid input passed in to commonCreateBroadcast.");          \
  std::vector<NodeValue> inputs = broadcastInputs(axis, {inputArgs...});

#define DECLARE_BROADCAST_NODE(NODE_NAME, NUM_INPUTS)                          \
  template <class T, class... Args>                                            \
  typename enable_if_same_t<T, NODE_NAME##Node>::type *                        \
  createNodeWithBroadcast(const std::string &name, int axis,                   \
                          Args &&... inputArgs) {                              \
    BROADCAST_FUNC_COMMON_CODE(NUM_INPUTS)                                     \
    return create##NODE_NAME(name, inputs[0].getType(), inputs[0], inputs[1]); \
  }

  /// Template function that creates a node and normalizes its input shapes
  /// with the use of Broadcast nodes. If axis is -1, it is calculated
  /// automatically for multi-directional broadcast. See the usage sketch after
  /// these declarations.
  DECLARE_BROADCAST_NODE(Mul, /* NUM_INPUTS */ 2)
  DECLARE_BROADCAST_NODE(Div, /* NUM_INPUTS */ 2)
  DECLARE_BROADCAST_NODE(Add, /* NUM_INPUTS */ 2)
  DECLARE_BROADCAST_NODE(Sub, /* NUM_INPUTS */ 2)
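
  // Illustrative usage of the macro-generated helpers above (assumes a
  // Function *F and NodeValues a of shape [2, 3] and b of shape [3]; not from
  // the original docs): with axis = -1, the [3] operand is broadcast
  // multi-directionally to [2, 3] before the Mul node is created:
  //   MulNode *M =
  //       F->createNodeWithBroadcast<MulNode>("mul", /* axis */ -1, a, b);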

#define DECLARE_CMP_BROADCAST_NODE(NODE_NAME)                                  \
  template <class T, class... Args>                                            \
  typename enable_if_same_t<T, NODE_NAME##Node>::type *                        \
  createNodeWithBroadcast(const std::string &name, int axis,                   \
                          Args &&... inputArgs) {                              \
    BROADCAST_FUNC_COMMON_CODE(2)                                              \
    return create##NODE_NAME(name, inputs[0], inputs[1]);                      \
  }

  /// Template function that creates a node and normalizes its input shapes
  /// with the use of Broadcast nodes. If axis is -1, it is calculated
  /// automatically for multi-directional broadcast.
  DECLARE_CMP_BROADCAST_NODE(CmpLT)
  DECLARE_CMP_BROADCAST_NODE(CmpEQ)
  DECLARE_CMP_BROADCAST_NODE(CmpLTE)
  DECLARE_CMP_BROADCAST_NODE(Min)
  DECLARE_CMP_BROADCAST_NODE(Max)

  /// Template function that creates a node and normalizes its input shapes
  /// with the use of Broadcast nodes. If axis is -1, it is calculated
  /// automatically for multi-directional broadcast.
  template <class T, class... Args>
  typename enable_if_same_t<T, SelectNode>::type *
  createNodeWithBroadcast(const std::string &name, int axis,
                          Args &&... inputArgs) {
    BROADCAST_FUNC_COMMON_CODE(3)
    return createSelect(name, inputs[1].getType(), inputs[0], inputs[1],
                        inputs[2]);
  }

#undef BROADCAST_FUNC_COMMON_CODE
#undef DECLARE_BROADCAST_NODE
#undef DECLARE_CMP_BROADCAST_NODE

  /// Create an element-wise GREATER THAN comparison between \p LHS and \p RHS
  /// by creating a CmpLTNode with given \p name and swapped inputs.
  CmpLTNode *createCmpGT(llvm::StringRef name, NodeValue LHS, NodeValue RHS);

  /// Create an element-wise GREATER THAN or EQUAL comparison between \p LHS and
  /// \p RHS by creating a CmpLTENode with given \p name and swapped inputs.
  CmpLTENode *createCmpGTE(llvm::StringRef name, NodeValue LHS, NodeValue RHS);

  /// Create a MulNode with given \p name which multiplies \p input with itself
  /// to produce an equivalent Square node.
  MulNode *createSquare(llvm::StringRef name, NodeValue input);

  /// Create a MulNode with given \p name and output type \p outTy which
  /// multiplies \p input with itself to produce an equivalent Square node.
  MulNode *createSquare(llvm::StringRef name, TypeRef outTy, NodeValue input);

  /// Create an equivalent LeakyRELU node with given \p name, \p input and slope
  /// \p alpha by using a SplatNode and a PRELU node.
  PReluNode *createLeakyRELU(llvm::StringRef name, NodeValue input,
                             float alpha);

  /// Create an equivalent LeakyRELU node with given \p name, \p outTy, \p input
  /// and slope \p alpha by using a SplatNode and a PRELU node.
  PReluNode *createLeakyRELU(llvm::StringRef name, TypeRef outTy,
                             NodeValue input, float alpha);

  /// Create a node that produces a boolean output of the same shape as
  /// \p input in which each element indicates whether the corresponding
  /// element in \p input is NaN.
  IsNaNNode *createIsNaN(llvm::StringRef name, NodeValue input);

  /// \returns a ReplaceNaNNode given \p name, \p input, and \p value.
  ReplaceNaNNode *createReplaceNaN(llvm::StringRef name, NodeValue input,
                                   float value);

  PowNode *createPow(llvm::StringRef name, NodeValue base, float exp);

  SelectNode *createSelect(llvm::StringRef name, NodeValue Cond, NodeValue LHS,
                           NodeValue RHS);

  SelectNode *createSelect(llvm::StringRef name, TypeRef outTy, NodeValue Cond,
                           NodeValue LHS, NodeValue RHS);

  SplatNode *createSplat(llvm::StringRef name, TypeRef ty, float value);

  TouchNode *createTouch(llvm::StringRef name, TypeRef ty);

  MatMulNode *createMatMul(llvm::StringRef name, NodeValue lhs, NodeValue rhs);

  MatMulNode *createMatMul(llvm::StringRef name, TypeRef outTy, NodeValue lhs,
                           NodeValue rhs);
1072 
1073   /// \p lhs and \p rhs are 3d matrices, where the leading dimension is the
1074   /// batch size. For each batch element number i, lhs.slice(i) is multiplied by
1075   /// rhs.slice(i).
1076   BatchMatMulNode *createBatchMatMul(llvm::StringRef name, NodeValue lhs,
1077                                      NodeValue rhs);
1078 
1079   /// Create a node, performing BatchedReduceAdd operation. Output type is
1080   /// based on the input \p batch type with dimensions specified with \p axes
1081   /// removed.
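  ///
  /// A minimal usage sketch (hypothetical names; assumes a Function *F and a
  /// NodeValue batch of shape [3, 4]):
  /// \code
  ///   // Reducing over axis 0 yields a result of shape [4].
  ///   BatchedReduceAddNode *sum = F->createBatchedReduceAdd("sum", batch, {0});
  /// \endcode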
1082   BatchedReduceAddNode *createBatchedReduceAdd(llvm::StringRef name,
1083                                                NodeValue batch,
1084                                                llvm::ArrayRef<unsigned_t> axes);
1085 
  /// Create a node, performing BatchedReduceAdd operation. Output type
  /// matches the provided \p outTy.
1088   BatchedReduceAddNode *createBatchedReduceAdd(llvm::StringRef name,
1089                                                TypeRef outTy, NodeValue batch,
1090                                                llvm::ArrayRef<unsigned_t> axes);
1091 
1092   /// Create a node, performing BatchedReduceMin operation. Output type is
1093   /// based on the input \p batch type with dimensions specified with \p axes
1094   /// removed.
1095   BatchedReduceMinNode *createBatchedReduceMin(llvm::StringRef name,
1096                                                NodeValue batch,
1097                                                llvm::ArrayRef<unsigned_t> axes);
1098 
  /// Create a node, performing BatchedReduceMean operation. Output type
  /// matches the provided \p outTy.
1101   BatchedReduceMeanNode *
1102   createBatchedReduceMean(llvm::StringRef name, TypeRef outTy, NodeValue batch,
1103                           llvm::ArrayRef<unsigned_t> axes);
1104 
1105   /// Create a node, performing BatchedReduceMean operation. Output type is
1106   /// based on the input \p batch type with dimensions specified with \p axes
1107   /// removed.
1108   BatchedReduceMeanNode *
1109   createBatchedReduceMean(llvm::StringRef name, NodeValue batch,
1110                           llvm::ArrayRef<unsigned_t> axes);
1111 
  /// Create a BatchedAddNode with given \p name that adds \p slice to each
  /// batch element of \p batch.
  BatchedAddNode *createBatchedAdd(llvm::StringRef name, NodeValue batch,
                                   NodeValue slice);
1114 
  /// Same as above, but with output type \p outTy specified explicitly.
  BatchedAddNode *createBatchedAdd(llvm::StringRef name, TypeRef outTy,
                                   NodeValue batch, NodeValue slice);
1117 
1118   /// Create a node performing a Cumulative Sum operation, output type matches
1119   /// \p input type.
1120   CumSumNode *createCumSum(llvm::StringRef name, NodeValue input,
1121                            bool exclusive = false, bool reverse = false);
1122 
1123   /// Implements an operation that accumulates the values in \p data along the
1124   /// first dimension into len(\p lengths) entries by summing together the first
1125   /// lengths[0] values, then the subsequent lengths[1] values, etc.
1126   /// sum(\p lengths) must equal the first dimension of \p data. This operation
  /// is similar to SparseLengthsSum but the input is a dense representation
1128   /// instead of a sparse one. In other words, it has already been Gathered.
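  ///
  /// A minimal usage sketch (hypothetical names; assumes a Function *F, a
  /// NodeValue data of shape [5, K], and a NodeValue lengths holding {2, 3}):
  /// \code
  ///   // out[0] = data[0] + data[1]; out[1] = data[2] + data[3] + data[4].
  ///   LengthsSumNode *out = F->createLengthsSum("lsum", data, lengths);
  /// \endcode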
1129   LengthsSumNode *createLengthsSum(llvm::StringRef name, NodeValue data,
1130                                    NodeValue lengths);
1131 
1132   /// Create a node, performing SparseLengthsSum operation:
1133   /// Gathers slices of the outer-most dimension of Data indexed by Indices
1134   /// vector, and then accumulates them into len(Lengths) entries:
1135   /// first Lengths[0] slices are aggregated to Result[0], next Lengths[1]
1136   /// slices are aggregated to Result[1], etc. I.e. sum(Lengths) must be equal
1137   /// to len(Indices). \p lengthsMode and \p avgLength represent meta
1138   /// information about the \p lengths, allowing the backend to use a
1139   /// specialized implementation.
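  ///
  /// A minimal usage sketch (hypothetical names; assumes a Function *F, a
  /// NodeValue data whose rows are D0..D3, indices holding {2, 0, 1, 2}, and
  /// lengths holding {2, 2}):
  /// \code
  ///   // Result[0] = D2 + D0; Result[1] = D1 + D2.
  ///   SparseLengthsSumNode *sls =
  ///       F->createSparseLengthsSum("sls", data, indices, lengths);
  /// \endcode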
1140   SparseLengthsSumNode *
1141   createSparseLengthsSum(llvm::StringRef name, NodeValue data,
1142                          NodeValue indices, NodeValue lengths,
1143                          LengthsMode lengthsMode = LengthsMode::Variable,
1144                          float avgLength = NAN);
1145 
1146   /// Same as SparseLengthsSum, but i-th slice is multiplied by weights[i].
1147   /// len(weights) must be equal to len(indices).
1148   SparseLengthsWeightedSumNode *createSparseLengthsWeightedSum(
1149       llvm::StringRef name, NodeValue data, NodeValue weights,
1150       NodeValue indices, NodeValue lengths,
1151       LengthsMode lengthsMode = LengthsMode::Variable, float avgLength = NAN);
1152 
1153   /// Create an EmbeddingBag node. If \p hasEndOffset is true then the node
1154   /// expects an extra offset to be appended to \p offsets which marks the end
1155   /// of the last range. \p lengthsMode and \p avgLength represent meta
1156   /// information about the \p lengths, allowing the backend to use a
1157   /// specialized implementation.
1158   EmbeddingBagNode *createEmbeddingBag(
1159       llvm::StringRef name, NodeValue data, NodeValue weights,
1160       NodeValue indices, NodeValue offsets, bool hasEndOffset = false,
1161       LengthsMode lengthsMode = LengthsMode::Variable, float avgLength = NAN);
1162 
1163   /// Create an EmbeddingBagByteRowwiseOffsetsNode node. If \p hasEndOffset is
1164   /// true then the node expects an extra offset to be appended to \p offsets
1165   /// which marks the end of the last range. \p lengthsMode and \p avgLength
1166   /// represent meta information about the \p lengths, allowing the backend to
1167   /// use a specialized implementation.
1168   EmbeddingBagByteRowwiseOffsetsNode *createEmbeddingBagByteRowwiseOffsets(
1169       llvm::StringRef name, NodeValue data, NodeValue weights,
1170       NodeValue indices, NodeValue offsets, bool useFP16Accumulation = false,
1171       bool hasEndOffset = false,
1172       LengthsMode lengthsMode = LengthsMode::Variable, float avgLength = NAN);
1173 
1174   /// Same as \ref createEmbeddingBagByteRowwiseOffsets(), but
1175   /// expects float input \p data, which is rowwise-quantized and fused
1176   /// internally. \p fusedElemKind represents the element kind to use for the
1177   /// final fused rowwise-quantized data. If \p hasEndOffset is true then the
1178   /// node expects an extra offset to be appended to \p offsets which marks the
1179   /// end of the last range.
1180   EmbeddingBagByteRowwiseOffsetsNode *createEmbeddingBagByteRowwiseOffsets(
1181       llvm::StringRef name, Tensor &data, NodeValue weights, NodeValue indices,
1182       NodeValue offsets, ElemKind fusedElemKind = ElemKind::UInt8FusedQTy,
1183       bool useFP16Accumulation = false, bool hasEndOffset = false,
1184       LengthsMode lengthsMode = LengthsMode::Variable, float avgLength = NAN);
1185 
1186   /// Same as \ref createSparseLengthsWeightedSum(), but with \p outTy
1187   /// specified.
1188   SparseLengthsWeightedSumNode *createSparseLengthsWeightedSum(
1189       llvm::StringRef name, TypeRef outTy, NodeValue data, NodeValue weights,
1190       NodeValue indices, NodeValue lengths,
1191       LengthsMode lengthsMode = LengthsMode::Variable, float avgLength = NAN);
1192 
1193   /// Creates and \returns a node of \p name, performing the SparseLengthsSum
1194   /// operation, using rowwise quantization for the input \p data with the \p
1195   /// scales and \p offsets as separate input tensors. Gathers slices of the
1196   /// outer-most dimension of data indexed by the \p indices vector, and then
1197   /// accumulates them into len(\p lengths) entries: first Lengths[0] slices are
1198   /// aggregated to Result[0], next Lengths[1] slices are aggregated to
1199   /// Result[1], etc. I.e. sum(Lengths) must be equal to len(Indices).
1200   /// \p precision represents what precision to use for Scale, Offset, and
1201   /// Result. If \p useFP16Accumulation, then internal arithmetic will use FP16
1202   /// accumulation; otherwise defaults to FP32. \p lengthsMode and \p avgLength
1203   /// represent meta information about the \p lengths, allowing the backend to
1204   /// use a specialized implementation.
1205   RowwiseQuantizedSparseLengthsWeightedSumNode *
1206   createRowwiseQuantizedSparseLengthsSum(
1207       llvm::StringRef name, Storage *data, Constant *scales, Constant *offsets,
1208       NodeValue indices, NodeValue lengths,
1209       ElemKind precision = ElemKind::FloatTy, bool useFP16Accumulation = false,
1210       LengthsMode lengthsMode = LengthsMode::Variable, float avgLength = NAN);
1211 
1212   /// Same as \ref createRowwiseQuantizedSparseLengthsSum(), but expects
1213   /// float input \p data, which is rowwise-quantized internally.
1214   RowwiseQuantizedSparseLengthsWeightedSumNode *
1215   createRowwiseQuantizedSparseLengthsSum(
1216       llvm::StringRef name, Tensor &data, NodeValue indices, NodeValue lengths,
1217       quantization::Schema schema, ElemKind precision = ElemKind::FloatTy,
1218       bool useFP16Accumulation = false,
1219       LengthsMode lengthsMode = LengthsMode::Variable, float avgLength = NAN);
1220 
1221   /// Same as \ref createRowwiseQuantizedSparseLengthsSum(), but i-th slice is
1222   /// multiplied by weights[i]. len(weights) must be equal to len(indices).
1223   RowwiseQuantizedSparseLengthsWeightedSumNode *
1224   createRowwiseQuantizedSparseLengthsWeightedSum(
1225       llvm::StringRef name, Storage *data, Constant *scales, Constant *offsets,
1226       NodeValue weights, NodeValue indices, NodeValue lengths,
1227       ElemKind precision = ElemKind::FloatTy, bool useFP16Accumulation = false,
1228       LengthsMode lengthsMode = LengthsMode::Variable, float avgLength = NAN);
1229 
1230   /// Same as \ref createRowwiseQuantizedSparseLengthsWeightedSum(), but expects
1231   /// float input \p data, which is rowwise-quantized internally.
1232   RowwiseQuantizedSparseLengthsWeightedSumNode *
1233   createRowwiseQuantizedSparseLengthsWeightedSum(
1234       llvm::StringRef name, Tensor &data, NodeValue weights, NodeValue indices,
1235       NodeValue lengths, quantization::Schema schema,
1236       ElemKind precision = ElemKind::FloatTy, bool useFP16Accumulation = false,
1237       LengthsMode lengthsMode = LengthsMode::Variable, float avgLength = NAN);
1238 
1239   /// Creates and \returns a node of \p name, performing the SparseLengthsSum
1240   /// operation, using fused rowwise quantization for the input \p data wherein
1241   /// the scales and offsets are fused inline with each row of data. \p data
1242   /// must be of a fused ElemKind. Gathers slices of the outer-most dimension of
1243   /// data indexed by the \p indices vector, and then accumulates them into
1244   /// len(\p lengths) entries: first Lengths[0] slices are aggregated to
1245   /// Result[0], next Lengths[1] slices are aggregated to Result[1], etc.  I.e.
1246   /// sum(Lengths) must be equal to len(Indices).  The precision for the Result
1247   /// is determined by the \p data input's ElemKind used for Scale and
1248   /// Offset. If \p useFP16Accumulation, then internal arithmetic will use FP16
1249   /// accumulation; otherwise defaults to FP32. \p lengthsMode and \p avgLength
1250   /// represent meta information about the \p lengths, allowing the backend to
1251   /// use a specialized implementation.
1252   FusedRowwiseQuantizedSparseLengthsSumNode *
1253   createFusedRowwiseQuantizedSparseLengthsSum(
1254       llvm::StringRef name, Storage *data, NodeValue indices, NodeValue lengths,
1255       bool useFP16Accumulation = false,
1256       LengthsMode lengthsMode = LengthsMode::Variable, float avgLength = NAN);
1257 
1258   /// Same as \ref createFusedRowwiseQuantizedSparseLengthsSum(), but expects
1259   /// float input \p data, which is rowwise-quantized and fused internally.
1260   /// \p fusedElemKind represents the element kind to use for the final fused
1261   /// rowwise-quantized data.
1262   FusedRowwiseQuantizedSparseLengthsSumNode *
1263   createFusedRowwiseQuantizedSparseLengthsSum(
1264       llvm::StringRef name, Tensor &data, NodeValue indices, NodeValue lengths,
1265       ElemKind fusedElemKind = ElemKind::UInt8FusedQTy,
1266       bool useFP16Accumulation = false,
1267       LengthsMode lengthsMode = LengthsMode::Variable, float avgLength = NAN);
1268 
1269   /// Same as \ref createFusedRowwiseQuantizedSparseLengthsSum(), but i-th slice
1270   /// is multiplied by weights[i]. len(weights) must be equal to len(indices).
1271   FusedRowwiseQuantizedSparseLengthsWeightedSumNode *
1272   createFusedRowwiseQuantizedSparseLengthsWeightedSum(
1273       llvm::StringRef name, NodeValue data, NodeValue weights,
1274       NodeValue indices, NodeValue lengths, bool useFP16Accumulation = false,
1275       LengthsMode lengthsMode = LengthsMode::Variable, float avgLength = NAN);
1276 
1277   /// Same as \ref createFusedRowwiseQuantizedSparseLengthsWeightedSum(), but
1278   /// expects float input \p data, which is rowwise-quantized and fused
1279   /// internally. \p fusedElemKind represents the element kind to use for the
1280   /// final fused rowwise-quantized data.
1281   FusedRowwiseQuantizedSparseLengthsWeightedSumNode *
1282   createFusedRowwiseQuantizedSparseLengthsWeightedSum(
1283       llvm::StringRef name, Tensor &data, NodeValue weights, NodeValue indices,
1284       NodeValue lengths, ElemKind fusedElemKind = ElemKind::UInt8FusedQTy,
1285       bool useFP16Accumulation = false,
1286       LengthsMode lengthsMode = LengthsMode::Variable, float avgLength = NAN);
1287 
1288   /// Given a vector of segment lengths, calculates offsets of each segment and
  /// packs them next to the lengths. For an input vector of length N the
  /// output is an Nx2 matrix with an (offset, length) pair for each segment.
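  ///
  /// A minimal usage sketch (hypothetical names; assumes a Function *F and a
  /// NodeValue lengths holding {2, 3, 1}):
  /// \code
  ///   // Produces {{0, 2}, {2, 3}, {5, 1}}: one (offset, length) row per segment.
  ///   LengthsToRangesNode *ranges = F->createLengthsToRanges("ranges", lengths);
  /// \endcode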
1291   LengthsToRangesNode *createLengthsToRanges(llvm::StringRef name,
1292                                              NodeValue lengths);
1293 
1294   /// Given a vector of \p lengths, \returns a LengthsRangeFillNode. This Node
1295   /// calculates a range sequence given \p lengths, where the sum of the
1296   /// elements of \p lengths must be no greater than \p maxOutputSize, which is
1297   /// used to set the output type.
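  ///
  /// A minimal usage sketch (hypothetical names; assumes a Function *F and a
  /// NodeValue lengths holding {2, 3}):
  /// \code
  ///   // Produces {0, 1, 0, 1, 2}; maxOutputSize must be >= sum(lengths).
  ///   LengthsRangeFillNode *fill = F->createLengthsRangeFill("fill", lengths, 5);
  /// \endcode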
1298   LengthsRangeFillNode *createLengthsRangeFill(llvm::StringRef name,
1299                                                NodeValue lengths,
1300                                                unsigned_t maxOutputSize);
1301 
1302   /// Implements an operation that converts the sparse representation given by
1303   /// the pair of \p indices and \p values into a dense representation.
1304   /// This representation contains each value of \p values at the corresponding
1305   /// index given by \p indices. All indices that are not present in \p indices
1306   /// are filled with zeroes. \p indices can contain duplicates, and in this
1307   /// case, the corresponding values in \p values are added.
1308   ///
1309   /// \p dataToInferDim acts as a hint about the shape of the output. The first
1310   /// dimension of the output is the first dimension of this tensor.
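  ///
  /// A minimal usage sketch (hypothetical names; assumes a Function *F,
  /// indices holding {0, 2, 0}, values holding {1, 2, 3}, and a
  /// dataToInferDim whose first dimension is 4):
  /// \code
  ///   // Produces {4, 0, 2, 0}: the duplicate index 0 sums its values (1 + 3).
  ///   SparseToDenseNode *dense =
  ///       F->createSparseToDense("s2d", indices, values, dataToInferDim);
  /// \endcode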
1311   SparseToDenseNode *createSparseToDense(llvm::StringRef name,
1312                                          NodeValue indices, NodeValue values,
1313                                          NodeValue dataToInferDim);
1314 
1315   /// Implements an operation that converts the sparse representation given by
1316   /// the pair of \p indices and \p values into a dense representation, which
1317   /// only contains IDs from given \p mask. Indices cannot contain duplicates.
1318   /// \p lengths is used to distinguish elements that belong to different
1319   /// examples of one batch. That is, first \p lengths[0] index-value pairs
1320   /// belong to batch's example 0, next \p lengths[1] pairs belong to example 1
1321   /// and so on.
1322   SparseToDenseMaskNode *
1323   createSparseToDenseMask(llvm::StringRef name, NodeValue indices,
1324                           NodeValue values, NodeValue defaultValue,
1325                           NodeValue lengths, llvm::ArrayRef<dim_t> mask);
1326 
1327   SaveNode *createSave(llvm::StringRef name, NodeValue input);
1328 
1329   /// Creates and \returns a SaveNode of \p input to \p output. If \p skipSuffix
1330   /// then the name used is \p name, otherwise suffix "_save" is appended.
1331   SaveNode *createSave(llvm::StringRef name, NodeValue input,
1332                        Placeholder *output, bool skipSuffix = false);
1333 
1334   /// Create quantization profile node named \p name for the output tensor from
  /// \p input in PlaceholderBindings \p bindings. Capture the observed node
  /// name in the quantization profile node, as the original node can be
  /// replaced during the lowering phase. Compute the histogram during
  /// profiling with \p numHistogramBins.
1338   QuantizationProfileNode *
1339   createQuantizationProfile(PlaceholderBindings &bindings, llvm::StringRef name,
1340                             NodeValue input, dim_t numHistogramBins = 10);
1341 
1342   /// Create lookup table for mapping between quantized numbers.
1343   /// \p input and \p outTy must have quantized type.
1344   /// Table contains all numbers from the quantized range, e.g.,
1345   /// 256 entries for int8. Position 0 in the \p initValues
1346   /// corresponds to the -128 input number, position 255 to 127.
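  ///
  /// A minimal usage sketch (hypothetical names; assumes a Function *F, an
  /// int8-quantized NodeValue input, and an int8-quantized TypeRef outTy):
  /// \code
  ///   // Identity table: entry i corresponds to the quantized input (i - 128).
  ///   std::vector<int8_t> table(256);
  ///   for (int i = 0; i < 256; i++)
  ///     table[i] = static_cast<int8_t>(i - 128);
  ///   IntLookupTableNode *lut = F->createIntLookupTable("lut", input, table, outTy);
  /// \endcode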
1347   IntLookupTableNode *createIntLookupTable(llvm::StringRef name,
1348                                            NodeValue input,
1349                                            llvm::ArrayRef<int8_t> initValues,
1350                                            TypeRef outTy);
1351 
1352   /// Create quantized tanh.
1353   IntLookupTableNode *createIntTanh(llvm::StringRef name, NodeValue input,
1354                                     TypeRef outTy);
1355 
1356   /// Create quantized sigmoid.
1357   IntLookupTableNode *createIntSigmoid(llvm::StringRef name, NodeValue input,
1358                                        TypeRef outTy);
1359 
  /// Create a TopKNode with given \p name that selects the \p k largest
  /// values along the innermost dimension of \p input, returning both the
  /// values and their indices.
  TopKNode *createTopK(llvm::StringRef name, NodeValue input, unsigned_t k);

  /// Same as above, but with the element kind of the output indices
  /// specified by \p outIndicesTyKind.
  TopKNode *createTopK(llvm::StringRef name, NodeValue input, unsigned_t k,
                       ElemKind outIndicesTyKind);
1364 
1365   /// Gathers entries of the outer-most dimension of \p data indexed by
1366   /// \p indices, and concatenates them. A non-zero \p batchDims specifies the
1367   /// batch, and the result is the concatenation of the operation on each sample
1368   /// in the batch.
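  ///
  /// A minimal usage sketch (hypothetical names; assumes a Function *F, a
  /// NodeValue data holding {{1,2},{3,4},{5,6}}, and indices holding {2, 0}):
  /// \code
  ///   // Produces {{5,6},{1,2}}.
  ///   GatherNode *g = F->createGather("gather", data, indices);
  /// \endcode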
1369   GatherNode *createGather(llvm::StringRef name, NodeValue data,
1370                            NodeValue indices, unsigned_t batchDims = 0);
1371 
1372   /// Create a node, performing GatherRanges operation:
1373   /// Gathers entries of \p data in groups specified by the "examples" in
1374   /// \p ranges. Each example in \p ranges contains a list of pairs of
1375   /// indices of the form (index, length) which specify which entries of \p
1376   /// data to gather. The ordering of elements in \p ranges and of pairs
1377   /// within an element is preserved in the output. In addition to the result
1378   /// of gathering ("output"), the lengths of the ranges gathered by each
1379   /// example in \p ranges is also produced as an output ("lengths").
1380   /// \p maxOutputSize is the maximum possible size of "output" and is used to
1381   /// set its type. Users must use "lengths" to interpret "output" correctly.
1382   /// \returns the GatherRangesNode.
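  ///
  /// A minimal usage sketch (hypothetical names; assumes a Function *F, data
  /// holding {1,2,3,4,5,6}, and ranges holding {{{0,1},{2,2}}, {{4,1},{5,1}}}):
  /// \code
  ///   // "output" is {1, 3, 4, 5, 6} and "lengths" is {3, 2}.
  ///   GatherRangesNode *gr = F->createGatherRanges("gr", data, ranges, 6);
  /// \endcode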
1383   GatherRangesNode *createGatherRanges(llvm::StringRef name, NodeValue data,
1384                                        NodeValue ranges,
1385                                        unsigned_t maxOutputSize);
1386 
1387   /// Copies each slice from \p slices into \p data at the corresponding index
1388   /// in \p indices, and \returns this new version of data. For example, given
1389   /// input data {{1,2},{3,4},{5,6}}, slices {{-3,-4}}, and indices {1}, the
1390   /// result is {{1,2},{-3,-4},{5,6}}. If \p cumulative is true, this node adds
1391   /// values instead of copying.
1392   ScatterDataNode *createScatterData(llvm::StringRef name, NodeValue data,
1393                                      NodeValue indices, NodeValue slices,
1394                                      bool cumulative = false);
1395 
1396   /// Given 2D matrix \p data, 1D vector \p lengths (of the same size as width
1397   /// of \p data), and 1D vector \p values (of the same size as sum of
1398   /// \p lengths), expand each row of the \p data to a row of zeros and ones,
1399   /// according to One Hot Encoding. j-th element of resulting i-th row is one
1400   /// iff \p values[j] == \p data[i][some index within range of j].
1401   BatchOneHotNode *createBatchOneHot(llvm::StringRef name, NodeValue data,
1402                                      NodeValue lengths, NodeValue values);
1403 
  /// Given an Input tensor of [N,H,W,C], where N is the batch axis, H is the
  /// height, W is the width, and C is the channel or depth, this produces an
  /// Output tensor of [N, H/blockSize, W/blockSize,
  /// C * blockSize * blockSize].
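  ///
  /// A minimal usage sketch (hypothetical names; assumes a Function *F and a
  /// NodeValue input of shape [1, 4, 4, 3]):
  /// \code
  ///   // With blockSize 2 the output has shape [1, 2, 2, 12].
  ///   SpaceToDepthNode *s2d = F->createSpaceToDepth("s2d", input, 2);
  /// \endcode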
1408   SpaceToDepthNode *createSpaceToDepth(llvm::StringRef name, NodeValue input,
1409                                        unsigned blockSize);
1410 
  /// Given \p input tensor, \returns an upsampled tensor in which the sizes
  /// of dimensions N, N-1, N-2, ..., N-numLeadingDims are doubled, copying
  /// the nearest pixel value to the new locations.
1414   ReshapeNode *createUpsample(llvm::StringRef name, NodeValue input,
1415                               dim_t numLeadingDims);
1416 
  /// Given \p input tensor of [N,H,W,C], where N is the batch, C is the
  /// channel or depth, H is the height and W is the width, and a \p scale
  /// array in the same layout as \p input, ResizeNearest generates an Output
  /// tensor with resized spatial dimensions using nearest neighbor
  /// interpolation. The Output tensor is of shape [floor(N * \p scale[0]),
  /// floor(H * \p scale[1]), floor(W * \p scale[2]),
  /// floor(C * \p scale[3])].
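  ///
  /// A minimal usage sketch (hypothetical names; assumes a Function *F and a
  /// NodeValue input of shape [1, 4, 4, 3]):
  /// \code
  ///   // Scales {1, 2, 2, 1} produce an output of shape [1, 8, 8, 3].
  ///   ResizeNearestNode *up =
  ///       F->createResizeNearest("resize", input, {1.f, 2.f, 2.f, 1.f});
  /// \endcode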
1424   ResizeNearestNode *createResizeNearest(llvm::StringRef name, NodeValue input,
1425                                          llvm::ArrayRef<float> scale);
1426 
  /// Given \p input tensor of [N,H,W,C], where N is the batch, C is the
  /// channel or depth, H is the height and W is the width, ResizeNearest
  /// generates an Output tensor with resized spatial dimensions using nearest
  /// neighbor interpolation. The Output tensor shape is specified with
  /// \p outTy.
1432   ResizeNearestNode *createResizeNearest(llvm::StringRef name, NodeValue input,
1433                                          TypeRef outTy);
1434 
  /// Given \p input tensor of [N,H,W,C], where N is the batch, C is the
  /// channel or depth, H is the height and W is the width, and a \p scale
  /// array in the same layout as \p input, ResizeBilinear generates an Output
  /// tensor with resized spatial dimensions using bilinear interpolation. The
  /// Output tensor is of shape [floor(N * \p scale[0]),
  /// floor(H * \p scale[1]), floor(W * \p scale[2]),
  /// floor(C * \p scale[3])].
1442   ResizeBilinearNode *createResizeBilinear(llvm::StringRef name,
1443                                            NodeValue input,
1444                                            llvm::ArrayRef<float> scale);
1445 
  /// Given \p input tensor of [N,H,W,C], where N is the batch, C is the
  /// channel or depth, H is the height and W is the width, ResizeBilinear
  /// generates an Output tensor with resized spatial dimensions using
  /// bilinear interpolation. The Output tensor shape is specified with
  /// \p outTy.
1451   ResizeBilinearNode *createResizeBilinear(llvm::StringRef name,
1452                                            NodeValue input, TypeRef outTy);
1453 
  /// Create quantization node which transforms a floating point tensor to a
1455   /// quantized one with given Scale and Offset. Scale and Offset params are
1456   /// part of the \p outTy.
1457   QuantizeNode *createQuantize(llvm::StringRef name, NodeValue input,
1458                                TypeRef outTy);
1459 
  /// Create dequantization node which transforms a quantized tensor to a
  /// floating point one with given Scale and Offset. Scale and Offset params
  /// are part of the \p input. The result element kind is \p k.
1463   DequantizeNode *createDequantize(llvm::StringRef name, NodeValue input,
1464                                    ElemKind k);
1465 
  /// Create dequantization node which transforms a quantized tensor to a
  /// floating point tensor of type \p outTy. Scale and Offset params are part
  /// of the \p input.
1469   DequantizeNode *createDequantize(llvm::StringRef name, NodeValue input,
1470                                    TypeRef outTy);
1471 
1472   /// Create transformation for quantized tensors to rescale based on the new
1473   /// Scale and Offset.
1474   RescaleQuantizedNode *createRescaleQuantized(llvm::StringRef name,
1475                                                NodeValue input, TypeRef outTy);
1476 
1477   /// Create a series of nodes that implement a weighted sum. \p data and \p
1478   /// weights should have the same number of elements. The nodes in \p weights
1479   /// should all be of size 1. Each node d_i in \p data is element-wise
1480   /// multiplied by the corresponding weight value w_i found in \p weights,
1481   /// broadcasted to the same shape as d_i, and resulting in r_i. All r_i are
1482   /// element-wise summed, and the final add node in this sum is returned.
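  ///
  /// A minimal usage sketch (hypothetical names; assumes a Function *F,
  /// NodeValues d0 and d1 of the same shape, and single-element NodeValues
  /// w0 and w1):
  /// \code
  ///   // Produces d0 * w0 + d1 * w1, with each weight broadcast to d_i's shape.
  ///   Node *sum = F->createWeightedSum("wsum", {d0, d1}, {w0, w1});
  /// \endcode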
1483   Node *createWeightedSum(llvm::StringRef name, llvm::ArrayRef<NodeValue> data,
1484                           llvm::ArrayRef<NodeValue> weights);
1485 
1486   /// Create a series of nodes that implements a two-parameter
1487   /// rowwise Box-Cox transform. For each element of the \p input x, this is
1488   /// defined as:
1489   ///
1490   /// y = ln(max(x + lambda2, 1e-6)), if lambda1 == 0
1491   ///     (max(x + lambda2, 1e-6)^lambda1 - 1)/lambda1, if lambda1 != 0
1492   ///
1493   /// The transform parameters \p lambda1 and \p lambda2 are vectors of size D
1494   /// that are broadcasted to match the size of \p input (NxD). The transform
1495   /// itself is implemented using elementwise Max, Add, Log (if lambda1 == 0),
1496   /// Pow, Splat, Sub, and Div (if lambda1 != 0) nodes with a Splat and Select
1497   /// node to select between the two cases listed above. \returns the final
  /// Select node. \p epsilon is used to ensure we do not divide by zero in
  /// the lambda1 == 0 case, since a Select chooses between the two results
  /// and so both paths are executed.
1501   Node *createBatchBoxCox(llvm::StringRef name, NodeValue input,
1502                           NodeValue lambda1, NodeValue lambda2,
1503                           float epsilon = std::numeric_limits<float>::min());
1504 
1505   /// Create a Clip node with the given \p name, \p input, minimum clip value
1506   /// \p min, maximum clip value \p max and output type \p outTy.
1507   ClipNode *createClip(llvm::StringRef name, NodeValue input, TypeRef outTy,
1508                        float min, float max);
1509 
1510   /// Create a Clip node with the given \p name, \p input, minimum clip value
1511   /// \p min, maximum clip value \p max. Result type will be implicitly set
1512   /// based on the \p input type.
1513   ClipNode *createClip(llvm::StringRef name, NodeValue input, float min,
1514                        float max);
1515 
1516   /// Creates and \returns a ClipNode to the min/max range of FP16 with \p name
1517   /// of \p input. Result type will be implicitly set based on the \p input
1518   /// type.
1519   ClipNode *createClipMinMaxFP16(llvm::StringRef name, NodeValue input);
1520 
1521   /// Creates and \returns a ClipNode to the min/max range of BFloat16 with \p
1522   /// name of \p input. Result type will be implicitly set based on the \p input
1523   /// type.
1524   ClipNode *createClipMinMaxBFloat16(llvm::StringRef name, NodeValue input);
1525 
1526   /// @name The builder functions below are identical to the builder functions
1527   /// above except that they create nodes that use Placeholder instead of
1528   /// Variables. The methods create and initialize the tensors in the
1529   /// PlaceholderBindings. As soon as we finish the Placeholder migration we'll
1530   /// delete these methods and merge them with the builder methods above. See
1531   /// issue #1334.
1532   ///@{
1533 
  /// Creates a BatchNormalizationNode with given \p name that normalizes
  /// \p input along the channel dimension \p channelIdx, using \p epsilon for
  /// numerical stability and \p momentum for the running statistics. The
  /// scale, bias, mean, and variance tensors are created implicitly and
  /// registered in \p bindings.
  BatchNormalizationNode *
  createBatchNormalization(PlaceholderBindings &bindings, llvm::StringRef name,
                           NodeValue input, unsigned_t channelIdx = 0,
                           float epsilon = 1e-5, float momentum = 0.9);
1538 
1539   /// Creates a ConvolutionNode with the given \p name which convolves the 4D
1540   /// \p input. \p kernels defines the size of the height and width dimensions
  /// of the convolutional filters. \p strides defines the number of steps
1542   /// to take in the input for each output cell. \p pads defines how many zero
1543   /// padding cells should be added to the input during convolution. \p group
1544   /// defines the number of groups the input and output channels should be
  /// divided into and convolved separately. \p dilation defines the factor by
  /// which the gap between two neighboring kernel elements is expanded along
  /// each axis. \p layout defines the Tensor layout and must be either NHWC
  /// or NCHW.
1548   ConvolutionNode *createConv(PlaceholderBindings &bindings,
1549                               llvm::StringRef name, NodeValue input,
1550                               dim_t outChannels,
1551                               llvm::ArrayRef<unsigned_t> kernels,
1552                               llvm::ArrayRef<unsigned_t> strides,
1553                               llvm::ArrayRef<unsigned_t> pads, unsigned_t group,
1554                               unsigned_t dilation = 1,
1555                               ConvolutionLayout layout = NHWC);
1556 
1557   /// Creates a ConvolutionNode with the given \p name which convolves the 4D
1558   /// \p input. \p kernel defines the size of the height and width dimensions of
  /// the convolutional filters. \p stride defines the number of steps to
1560   /// take in the input for each output cell. \p pad defines how many zero
1561   /// padding cells should be added to the input during convolution. \p group
1562   /// defines the number of groups the input and output channels should be
  /// divided into and convolved separately. \p dilation defines the factor by
  /// which the gap between two neighboring kernel elements is expanded along
  /// each axis. \p layout defines the Tensor layout and must be either NHWC
  /// or NCHW.
1566   ConvolutionNode *createConv(PlaceholderBindings &bindings,
1567                               llvm::StringRef name, NodeValue input,
1568                               dim_t outChannels, unsigned_t kernel,
1569                               unsigned_t stride, unsigned_t pad,
1570                               unsigned_t group, unsigned_t dilation = 1,
1571                               ConvolutionLayout layout = NHWC);
1572 
1573   /// Creates a Convolution3DNode with the given \p name which convolves the 5D
1574   /// \p input. \p kernels defines the size of the height, width, and depth
  /// dimensions of the convolutional filters. \p strides defines the number
1576   /// of steps to take in the input for each output cell. \p pads defines how
1577   /// many zero padding cells should be added to the input during convolution.
1578   /// \p group defines the number of groups the input and output channels should
1579   /// be divided into and convolved separately.
1580   Convolution3DNode *createConv3D(PlaceholderBindings &bindings,
1581                                   llvm::StringRef name, NodeValue input,
1582                                   dim_t outChannels,
1583                                   llvm::ArrayRef<unsigned_t> kernels,
1584                                   llvm::ArrayRef<unsigned_t> strides,
1585                                   llvm::ArrayRef<unsigned_t> pads,
1586                                   unsigned_t group);
1587 
1588   /// Creates a Convolution3DNode with the given \p name which convolves the 5D
1589   /// \p input. \p kernel defines the size of the height, width, and depth
  /// dimensions of the convolutional filters. \p stride defines the number
1591   /// of steps to take in the input for each output cell. \p pad defines how
1592   /// many zero padding cells should be added to the input during convolution.
1593   /// \p group defines the number of groups the input and output channels should
1594   /// be divided into and convolved separately.
1595   Convolution3DNode *createConv3D(PlaceholderBindings &bindings,
1596                                   llvm::StringRef name, NodeValue input,
1597                                   size_t outChannels, unsigned_t kernel,
1598                                   unsigned_t stride, unsigned_t pad,
1599                                   unsigned_t group);
1600 
1601   /// Creates a ConvTransposeNode with the given \p name which does transposed
1602   /// convolution on the 4D \p input. \p kernels define the size of the height
1603   /// and width dimensions of the convolution filters. \p strides define the
1604   /// number of steps to take in the input for each output cell. \p pads define
1605   /// how many zero padding cells should be added to the input during
1606   /// convolution. \p group defines the number of groups the input and output
1607   /// channels should be divided into and convolved separately.
1608   ConvTransposeNode *createConvTranspose(
1609       PlaceholderBindings &bindings, llvm::StringRef name, NodeValue input,
1610       dim_t outChannels, llvm::ArrayRef<unsigned_t> kernels,
1611       llvm::ArrayRef<unsigned_t> strides, llvm::ArrayRef<unsigned_t> pads,
1612       unsigned_t group, unsigned_t dilation = 1);
1613 
1614   /// Creates a ConvTransposeNode with the given \p name which does transposed
1615   /// convolution on the 4D \p input. \p kernel defines the size of the height
1616   /// and width dimensions of the convolution filters. \p stride defines the
1617   /// number of steps to take in the input for each output cell. \p pad defines
1618   /// how many zero padding cells should be added to the input during
1619   /// convolution. \p group defines the number of groups the input and output
1620   /// channels should be divided into and convolved separately.
1621   ConvTransposeNode *createConvTranspose(PlaceholderBindings &bindings,
1622                                          llvm::StringRef name, NodeValue input,
1623                                          dim_t outChannels, unsigned_t kernel,
1624                                          unsigned_t stride, unsigned_t pad,
1625                                          unsigned_t group,
1626                                          unsigned_t dilation = 1);
1627 
1628   /// Creates and \returns a FullyConnectedNode with \p name, \p input, weights
1629   /// \p W, bias \p B. If \p input is not 2 dimensional then it is flattened
1630   /// along \p axis. Note, output type is inferred based on the input
1631   /// types. Trainable weight and bias variables are created implicitly.
1632   FullyConnectedNode *createFullyConnected(PlaceholderBindings &bindings,
1633                                            llvm::StringRef name,
1634                                            NodeValue input, dim_t outDepth,
1635                                            unsigned_t axis = 1);
1636 
1637   /// Create an unrolled single-layer Simple RNN cell with \p hiddenSize
1638   /// dimensionality of the hidden state and \p outputSize dimensionality of the
1639   /// output state. \p inputs define the input for the cell at each time step
1640   /// and the number of time steps is equal to the size of the \p inputs. The
1641   /// names of the created variables are prefixed by \p namePrefix.
1642   /// The output variables are written to \p outputs, they represent the
1643   /// activations of the output layer, unrolled over time.
  /// The dimensionality of the output variables is \p batchSize x \p outputSize.
1645   void createSimpleRNN(PlaceholderBindings &bindings,
1646                        llvm::StringRef namePrefix,
1647                        const llvm::ArrayRef<NodeValue> inputs,
1648                        unsigned batchSize, unsigned hiddenSize,
1649                        unsigned outputSize, std::vector<NodeValue> &outputs);
1650 
1651   /// Create an unrolled single-layer GRU cell with \p hiddenSize
1652   /// dimensionality of the hidden state and \p outputSize dimensionality of the
1653   /// output state. \p inputs define the input for the cell at each time step
1654   /// and the number of time steps is equal to the size of the \p inputs. The
1655   /// names of the created variables are prefixed by \p namePrefix.
1656   /// The output variables are written to \p outputs, they represent the
1657   /// activation of the output layer, unrolled over time.
  /// The dimensionality of the output variables is \p batchSize x \p outputSize.
1659   void createGRU(PlaceholderBindings &bindings, llvm::StringRef namePrefix,
1660                  const llvm::ArrayRef<NodeValue> inputs, unsigned batchSize,
1661                  unsigned hiddenSize, unsigned outputSize,
1662                  std::vector<NodeValue> &outputs);
1663 
1664   /// Create an unrolled single-layer LSTM cell with \p hiddenSize
1665   /// dimensionality of the hidden state and \p outputSize dimensionality of the
1666   /// output state. \p inputs define the input for the cell at each time step
1667   /// and the number of time steps is equal to the size of the \p inputs. The
1668   /// names of the created variables are prefixed by \p namePrefix.
1669   /// The output variables are written to \p outputs, they represent the
1670   /// activation of the output layer, unrolled over time.
  /// The dimensionality of the output variables is \p batchSize x \p outputSize.
1672   void createLSTM(PlaceholderBindings &bindings, llvm::StringRef namePrefix,
1673                   const llvm::ArrayRef<NodeValue> inputs, unsigned batchSize,
1674                   unsigned hiddenSize, unsigned outputSize,
1675                   std::vector<NodeValue> &outputs);
1676 
1677   /// Type definition for the direction of an RNN module (RNN, GRU, LSTM).
1678   enum class RnnDirection {
1679     Forward,
1680     Reverse,
1681     Bidirectional,
1682   };
1683 
1684   /// Definition for a lambda used to create an activation node for RNN modules.
1685   using RnnActivation = std::function<Node *(llvm::StringRef, Node *)>;
1686 
1687   /// Create an unrolled multi-layer RNN according to the ONNX definition:
1688   /// https://github.com/onnx/onnx/blob/master/docs/Operators.md#RNN
1689   /// The RNN has the following inputs:
1690   /// - input \p X with size [S, B, ISize].
  /// - weights \p W with size [N, HSize, ISize].
  /// - recurrence weights \p R with size [N, HSize, HSize].
1693   /// - bias weights \p B with size [N, 2 * HSize].
1694   /// - initial hidden state \p initial_h with size [N, B, HSize].
1695   /// where S is the sequence length, N is the number of directions, B is the
1696   /// batch size, ISize is the input size and HSize is the hidden size.
1697   /// The RNN has the following outputs:
1698   /// - output \p Y with size [S, N, B, HSize].
1699   /// - final hidden state \p Y_h with size [N, B, HSize].
  /// The direction of the instantiated RNN is given by \p direction. The RNN
1701   /// will use the activation functions defined by the \p activations array:
1702   /// - [f] in case the RNN is unidirectional (1 function).
  /// - [f] for the forward cell followed by [f] for the reverse cell in
  ///    case the RNN is bidirectional (2 functions).
1705   /// The input \p B is optional (assumed 0 if nullptr is provided).
1706   /// The names of all the nodes created are prefixed with \p namePrefix.
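  ///
  /// A minimal usage sketch (hypothetical names; assumes a Function *F,
  /// properly sized NodeValues X, W, R, B, and initial_h, and the Function's
  /// createTanh builder):
  /// \code
  ///   NodeValue Y, Y_h;
  ///   std::vector<Function::RnnActivation> acts = {
  ///       [&](llvm::StringRef name, Node *input) -> Node * {
  ///         return F->createTanh(name, input);
  ///       }};
  ///   F->createOnnxRNN("rnn", X, W, R, B, initial_h, Y, Y_h,
  ///                    /* hiddenSize */ 64, Function::RnnDirection::Forward,
  ///                    acts);
  /// \endcode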
1707   void createOnnxRNN(llvm::StringRef namePrefix, NodeValue X, NodeValue W,
1708                      NodeValue R, NodeValue B, NodeValue initial_h,
1709                      NodeValue &Y, NodeValue &Y_h, unsigned hiddenSize,
1710                      RnnDirection direction,
1711                      std::vector<RnnActivation> &activations);
1712 
1713   /// Create an unrolled multi-layer GRU according to the ONNX definition:
1714   /// https://github.com/onnx/onnx/blob/master/docs/Operators.md#GRU
1715   /// The GRU has the following inputs:
1716   /// - input \p X with size [S, B, ISize].
  /// - weights \p W with size [N, 3 * HSize, ISize].
  /// - recurrence weights \p R with size [N, 3 * HSize, HSize].
1719   /// - bias weights \p B with size [N, 6 * HSize].
1720   /// - initial hidden state \p initial_h with size [N, B, HSize].
1721   /// where S is the sequence length, N is the number of directions, B is the
1722   /// batch size, ISize is the input size and HSize is the hidden size.
1723   /// The GRU has the following outputs:
1724   /// - output \p Y with size [S, N, B, HSize].
1725   /// - final hidden state \p Y_h with size [N, B, HSize].
  /// The direction of the instantiated GRU is given by \p direction. The GRU
1727   /// will use the activation functions defined by the \p activations array:
1728   /// - [f,g] in case the GRU is unidirectional (2 functions).
1729   /// - [f,g] for the forward cell followed by [f,g] for the reverse cell in
1730   ///    case the GRU is bidirectional (4 functions).
1731   /// The input \p B is optional (assumed 0 if nullptr is provided).
1732   /// The names of all the nodes created are prefixed with \p namePrefix.
1733   /// The boolean parameter \p linearBeforeReset defines whether the reset
1734   /// for the previous hidden state occurs before/after the linear expression.
1735   void createOnnxGRU(llvm::StringRef namePrefix, NodeValue X, NodeValue W,
1736                      NodeValue R, NodeValue B, NodeValue initial_h,
1737                      NodeValue &Y, NodeValue &Y_h, unsigned hiddenSize,
1738                      RnnDirection direction,
1739                      std::vector<RnnActivation> &activations,
1740                      bool linearBeforeReset = false);
1741 
1742   /// Create an unrolled multi-layer LSTM according to the ONNX definition:
1743   /// https://github.com/onnx/onnx/blob/master/docs/Operators.md#LSTM
1744   /// The LSTM has the following inputs:
1745   /// - input \p X with size [S, B, ISize].
  /// - weights \p W with size [N, 4 * HSize, ISize].
  /// - recurrence weights \p R with size [N, 4 * HSize, HSize].
1748   /// - bias weights \p B with size [N, 8 * HSize].
1749   /// - initial hidden state \p initial_h with size [N, B, HSize].
1750   /// - initial cell state \p initial_c with size [N, B, HSize].
1751   /// - peephole weights \p P with size [N, 3 * HSize].
1752   /// where S is the sequence length, N is the number of directions, B is the
1753   /// batch size, ISize is the input size and HSize is the hidden size.
1754   /// The LSTM has the following outputs:
1755   /// - output \p Y with size [S, N, B, HSize].
1756   /// - final hidden state \p Y_h with size [N, B, HSize].
1757   /// - final cell state \p Y_c with size [N, B, HSize].
  /// The direction of the instantiated LSTM is given by \p direction. The LSTM
1759   /// will use the activation functions defined by \p activations array:
1760   /// - [f,g,h] in case the LSTM is unidirectional (3 functions).
1761   /// - [f,g,h] for the forward cell followed by [f,g,h] for the reverse cell in
1762   ///    case the LSTM is bidirectional (6 functions).
1763   /// The inputs \p B and \p P are optional (assumed 0 if nullptr is provided).
1764   /// The names of all the nodes created are prefixed with \p namePrefix.
1765   /// The boolean parameter \p inputForget defines whether the input and forget
1766   /// gates should be coupled (compute the input gate from the forget gate).
1767   void createOnnxLSTM(llvm::StringRef namePrefix, NodeValue X, NodeValue W,
1768                       NodeValue R, NodeValue B, NodeValue initial_h,
1769                       NodeValue initial_c, NodeValue P, NodeValue &Y,
1770                       NodeValue &Y_h, NodeValue &Y_c, unsigned hiddenSize,
1771                       RnnDirection direction,
1772                       std::vector<RnnActivation> &activations,
1773                       bool inputForget = false);
1774   /// @}
1775 
1776   /// Create a TraceEvent in the runtime profile, which triggers collection of
1777   /// runtime statistics.
1778   TraceEventNode *createTraceEvent(llvm::StringRef eventName,
1779                                    llvm::StringRef eventType, Node *data,
1780                                    unsigned index);
1781 
1782   /// Creates NMSv4 node that does NMS for one class.
1783   /// Inputs
1784   /// - \p boxes Tensor with box coordinates.
1785   /// - \p scores Tensor with scores per box.
1786   /// - \p centerPointBox Indicates format of the box per ONNX spec.
1787   /// - \p iouThreshold Threshold for box overlap.
1788   /// - \p scoreThreshold Threshold for box scores.
1789   NonMaxSuppressionNode *
1790   createNonMaxSuppressionV4(llvm::StringRef name, NodeValue boxes,
1791                             NodeValue scores, int64_t centerPointBox,
1792                             int64_t maxOutputBoxesPerClass, float iouThreshold,
1793                             float scoreThreshold);
1794 
1795   /// Creates NMSv4 node that does NMS for one class.
1796   /// Inputs
1797   /// - \p boxes Tensor with box coordinates.
1798   /// - \p scores Tensor with scores per box.
1799   /// - \p centerPointBox Indicates format of the box per ONNX spec.
1800   /// - \p iouThreshold Threshold for box overlap.
1801   /// - \p scoreThreshold Threshold for box scores.
  /// - \p elTy Output ElemKind.
1803   NonMaxSuppressionNode *
1804   createNonMaxSuppressionV4(llvm::StringRef name, NodeValue boxes,
1805                             NodeValue scores, int64_t centerPointBox,
1806                             int64_t maxOutputBoxesPerClass, float iouThreshold,
1807                             float scoreThreshold, ElemKind elTy);
1808 
1809   /// Creates NMSv4 node that does NMS for one class.
1810   /// Inputs
1811   /// - \p boxes Tensor with box coordinates.
1812   /// - \p scores Tensor with scores per box.
1813   /// - \p centerPointBox Indicates format of the box per ONNX spec.
1814   /// - \p iouThreshold Threshold for box overlap.
1815   /// - \p scoreThreshold Threshold for box scores.
1816   /// - \p indicesTy Type of indices output.
  /// - \p numberOfSelectedIndicesTy Type of the second output carrying the
  /// number of boxes detected.
1819   NonMaxSuppressionNode *
1820   createNonMaxSuppressionV4(llvm::StringRef name, NodeValue boxes,
1821                             NodeValue scores, int64_t centerPointBox,
1822                             int64_t maxOutputBoxesPerClass, float iouThreshold,
1823                             float scoreThreshold, TypeRef indicesTy,
1824                             TypeRef numberOfSelectedIndicesTy);
1825 
1826   /// Performs class wise NMS based on ONNX specification, with padding and ONNX
1827   /// layout output.
1828   /// Inputs
1829   /// - \p boxes Tensor with box coordinates.
1830   /// - \p scores Tensor with scores per box.
1831   /// - \p centerPointBox Indicates format of the box per ONNX spec.
1832   /// - \p iouThreshold Threshold for box overlap.
1833   /// - \p scoreThreshold Threshold for box scores.
1834   NonMaxSuppressionNode *
1835   createNonMaxSuppressionONNX(llvm::StringRef name, NodeValue boxes,
1836                               NodeValue scores, int64_t centerPointBox,
1837                               int64_t maxOutputBoxesPerClass,
1838                               float iouThreshold, float scoreThreshold);
1839 
1840   /// Performs class wise NMS based on ONNX specification, with padding and ONNX
1841   /// layout output.
1842   /// Inputs
1843   /// - \p boxes Tensor with box coordinates.
1844   /// - \p scores Tensor with scores per box.
1845   /// - \p centerPointBox Indicates format of the box per ONNX spec.
1846   /// - \p iouThreshold Threshold for box overlap.
1847   /// - \p scoreThreshold Threshold for box scores.
1848   NonMaxSuppressionNode *createNonMaxSuppressionONNX(
1849       llvm::StringRef name, NodeValue boxes, NodeValue scores,
1850       int64_t centerPointBox, int64_t maxOutputBoxesPerClass,
1851       float iouThreshold, float scoreThreshold, ElemKind elTy);
1852 
1853   /// Performs class wise NMS based on ONNX specification, with padding and ONNX
1854   /// layout output.
1855   /// Inputs
1856   /// - \p boxes Tensor with box coordinates.
1857   /// - \p scores Tensor with scores per box.
1858   /// - \p centerPointBox Indicates format of the box per ONNX spec.
1859   /// - \p iouThreshold Threshold for box overlap.
1860   /// - \p scoreThreshold Threshold for box scores.
1861   NonMaxSuppressionNode *createNonMaxSuppressionONNX(
1862       llvm::StringRef name, NodeValue boxes, NodeValue scores,
1863       int64_t centerPointBox, int64_t maxOutputBoxesPerClass,
1864       float iouThreshold, float scoreThreshold, TypeRef indicesTy);
1865 
1866   /// Create a constant node with a 1D cosine windowing function defined as:
1867   /// w[n] = 0.5 - 0.5 * cos(2 * pi * n / N) for n = 0 .. N - 1 where N
1868   /// is the window \p length. The node name will be \p name.
1869   Constant *createCosineWindow(llvm::StringRef name, dim_t length);
1870 
1871   /// Create a constant node with the twiddle factors for a 1D complex FFT:
  /// W(N, k) = exp(-j * 2 * pi * k / N) for k = 0 ... N - 1, where N is the
1873   /// \p fftLength. The constant node will contain 2 * \p fftLength real float
1874   /// values corresponding to \p fftLength complex values with the real and
1875   /// imaginary parts interleaved: real[0], imag[0], real[1], imag[1], etc.
1876   /// The node name will be \p name.
1877   Constant *createFFTTwiddleFactors(llvm::StringRef name, dim_t fftLength);
1878 
1879   /// Create a constant node with the bit reverse indices for a 1D FFT, that
1880   /// is the corresponding index obtained after reversing the bit order for
  /// each of the values k = 0 ... N - 1 where N is the \p fftLength. The node
1882   /// will contain \p fftLength int32 values. The node name will be \p name.
1883   Constant *createFFTBitReverseIndices(llvm::StringRef name, dim_t fftLength);
1884 
1885   /// Create a constant node with the complex weights used to map the results
1886   /// of N/2 point complex FFT to a N point real FFT. This allows an efficient
1887   /// implementation of the N point FFT for a real data x[n] with n = 0 .. N-1
1888   /// by first computing the N/2 complex FFT G[k] for the complex signal g[n]
1889   /// defined as g[n] = x[2*n+0] + j * x[2*n+1] with n = 0 ... N/2-1 and then
1890   /// computing the final N point FFT X[k] for the original data x[n] by using
1891   /// X[k] = G[k] * A[k] + conj(G[N/2-k]) * (1 - A[k]) for k = 0 ... N/2 (for
1892   /// a real signal the FFT is conjugate symmetrical and therefore only the
1893   /// first N/2+1 output points of X[k] should be computed, the others being
1894   /// redundant). The relation should also use the definitions G[N/2] = G[0] and
1895   /// then A[k] = 1/2 * (1 - j * exp(-j * 2 * pi * k / N)) for k = 0 ... N/2.
1896   /// The FFT length parameter N is given as \p fftLength. This constant node
1897   /// will contain the complex values of A[k] for k = 0 ... L-1 where L is the
1898   /// sequence length given as \p outLength (the required length L is smaller
1899   /// than N/2+1 since A[k] has such properties that the second half of the
1900   /// sequence can be easily deduced from first half). This constant node will
1901   /// contain 2 * \p outLength real float values corresponding to \p outLength
1902   /// complex values A[k] with the real and imaginary parts interleaved.
1903   Constant *createFFTComplexToRealWeights(llvm::StringRef name, dim_t fftLength,
1904                                           dim_t outLength);
1905 
1906   /// This node computes the spectrogram of a 1D mono audio signal \p input by
1907   /// extracting windows of size \p windowSize with stride \p windowStride and
1908   /// computing for each window the spectrum power (magnitude squared) or simply
1909   /// the magnitude depending on the flag \p magnitudeSquared. If the length of
1910   /// the \p input is [inputLength] samples then the size of the spectrogram is
1911   /// [windowCount, fftLength/2+1] where:
1912   /// - windowCount = floor((inputLength-windowSize)/windowStride)+1 is the
1913   ///   number of windows extracted from the input.
1914   /// - fftLength is the FFT length used to compute the spectrogram which is the
1915   ///   next power of 2 (e.g. for a window size of 640 the fftLength is 1024).
1916   /// The input audio data values are commonly float values scaled in the range
1917   /// [-1.0, 1.0]. If the audio data is decoded from a WAV file into int8/int16
1918   /// values then those values are commonly scaled down with 2^7/2^15 before
1919   /// using this node. The node name will be \p name. This node is inspired from
1920   /// TensorFlow (tensorflow.python.ops.gen_audio_ops.audio_spectrogram).
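  ///
  /// A minimal usage sketch (hypothetical names; assumes a Function *F and a
  /// NodeValue input with 16000 float samples):
  /// \code
  ///   // windowCount = floor((16000 - 640) / 320) + 1 = 49 and the FFT
  ///   // length is 1024, so the spectrogram has shape [49, 513].
  ///   AudioSpectrogramNode *spec =
  ///       F->createAudioSpectrogram("spec", input, 640, 320);
  /// \endcode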
1921   AudioSpectrogramNode *createAudioSpectrogram(llvm::StringRef name,
1922                                                NodeValue input,
1923                                                int64_t windowSize,
1924                                                int64_t windowStride,
1925                                                bool magnitudeSquared = true);
1926 
1927   /// Create as constants the Mel weights \p melWeights and ranges \p melRanges
1928   /// required for the MFCC (Mel Frequency Cepstral Coefficient) transform for a
1929   /// spectrogram of length \p spectrogramLength (which must be of the form
1930   /// 2 ^ N + 1) obtained for an audio signal with the given \p sampleRate
1931   /// (in Hertz) by mapping the spectrogram coefficients in \p filterBankCount
1932   /// bins on a Mel scale between \p lowerFrequency and \p upperFrequency
1933   /// (in Hertz) using a filterbank of triangular windows. The constant nodes
1934   /// will be named using \p prefix.
1935   void createMelWeights(llvm::StringRef prefix, dim_t spectrogramLength,
1936                         float sampleRate, float lowerFrequency,
1937                         float upperFrequency, dim_t filterBankCount,
1938                         Constant *&melWeights, Constant *&melRanges);
1939 
1940   /// Create the DCT-II transform matrix coefficients as a constant defined as:
1941   /// d[k][n] = sqrt(2 / N) * cos(pi / N * (n + 1/2) * k) with n = 0 .. N - 1
1942   /// and k = 0 .. K - 1 where \p N is the input data length and \p K is the
1943   /// output data length. The common case is that for which the input length
1944   /// \p N is equal to the output length \p K but a separate output length
1945   /// argument \p K <= \p N allows creating a partial DCT matrix used to compute
1946   /// only the first \p K results from the full DCT-II transform. The DCT matrix
1947   /// size will be \p K x \p N.  The node name will be \p name.
1948   Constant *createDCTMat(llvm::StringRef name, dim_t N, dim_t K);

  /// Computes the MFCC (Mel Frequency Cepstral Coefficient) for the given
  /// \p spectrogram; this is commonly used as a feature extractor for
  /// voice/speech audio data in voice command or keyword spotting applications.
  /// The input \p spectrogram is a power spectrogram and not a magnitude
  /// spectrogram (computed using the 'AudioSpectrogram' node with the
  /// 'magnitudeSquared' flag set to true). The MFCC transform is computed
  /// using the given \p sampleRate (in Hertz) by mapping the spectrogram
  /// coefficients in \p filterBankCount bins on a Mel scale between
  /// \p lowerFrequency and \p upperFrequency (in Hertz) using a filterbank of
  /// triangular windows, taking the natural logarithm and then keeping the
  /// first \p numCoefficients from the DCT-II transform. If the input
  /// \p spectrogram size is [windowCount, spectrogramLen] then the output node
  /// size will be [windowCount, numCoefficients] since the MFCC transform is
  /// performed separately for each window of [spectrogramLen] input samples,
  /// yielding \p numCoefficients output samples. This node is inspired by
  /// TensorFlow (tensorflow.python.ops.gen_audio_ops.mfcc).
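  ///
  /// A minimal end-to-end sketch chaining the creators above (hypothetical:
  /// assumes a Function *F and a NodeValue audio; values are illustrative
  /// only):
  /// \code
  ///   auto *spec = F->createAudioSpectrogram("spec", audio, 640, 320,
  ///                                          /* magnitudeSquared */ true);
  ///   auto *mfcc = F->createMFCC("mfcc", spec, /* sampleRate */ 16000.0f,
  ///                              /* lowerFrequency */ 20.0f,
  ///                              /* upperFrequency */ 4000.0f,
  ///                              /* filterBankCount */ 40,
  ///                              /* numCoefficients */ 13);
  /// \endcode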
  MFCCNode *createMFCC(llvm::StringRef name, NodeValue spectrogram,
                       float sampleRate, float lowerFrequency,
                       float upperFrequency, int64_t filterBankCount,
                       int64_t numCoefficients);

  /// Erase the node \p N from the Function.
  void eraseNode(Node *N);

  /// Erase the node pointed to by the iterator \p I from the Function.
  void eraseNode(NodesList::iterator I);

  /// Clone the current function into a new function with the name \p newName in
  /// the same module. If \p map is non-null then the procedure records the
  /// mapping from each old node to its new node in \p map. If \p currToNewMap
  /// is non-null it is used as the initial state of the currToNew map inside
  /// the cloner.
  /// \returns a new function that is a copy of the current function.
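  ///
  /// For example (a sketch; also captures the old-to-new node mapping):
  /// \code
  ///   llvm::DenseMap<const Node *, Node *> map;
  ///   Function *copy = F->clone("main_clone", &map);
  /// \endcode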
  Function *clone(llvm::StringRef newName,
                  llvm::DenseMap<const Node *, Node *> *map = nullptr,
                  llvm::DenseMap<const Node *, Node *> *currToNewMap = nullptr);

  /// Clone the current function into a user-provided function \p newF. The
  /// function \p newF is not automatically added to a module by the clone call.
  /// If \p map is non-null then the procedure records the mapping from each
  /// old node to its new node in \p map. If \p currToNewMap is non-null it is
  /// used as the initial state of the currToNew map inside the cloner.
  /// \returns the user-provided function \p newF, which now contains a clone
  /// of the current function.
  Function *
  clone(Function *newF, llvm::DenseMap<const Node *, Node *> *map = nullptr,
        llvm::DenseMap<const Node *, Node *> *currToNewMap = nullptr) const;

  /// Verify the correctness of the Function. If \p backend is provided, checks
  /// backend-specific layout requirements; otherwise checks the requirements
  /// based on Glow's "canonical" layout.
  /// \returns true when the function is valid, false otherwise.
  bool verify(const Backend *backend = nullptr) const;

  /// Dump a textual representation of the Function into the default output
  /// stream.
  void dump() const;

  /// Dump a textual representation of the Function to std::string. If
  /// \p skipUsersForStorage then user counts for Storage will not be dumped.
  /// If \p skipName then the name of the Function will not be dumped.
  std::string toString(bool skipUsersForStorage = false,
                       bool skipName = false) const;

  /// \returns a hash code of the function.
  llvm::hash_code getHash() const;

  /// Dump a textual representation of the Function into the provided output
  /// stream \p os. If \p skipUsersForStorage then user counts for Storage will
  /// not be dumped. If \p skipName then the name of the Function will not be
  /// dumped.
  void dump(llvm::raw_ostream &os, bool skipUsersForStorage = false,
            bool skipName = false) const;

  /// Dump a dotty graph that depicts the function into a file.
  /// \returns the full path to the file.
  std::string dumpDAG();

  /// Dump a dotty graph that depicts the function into the file
  /// \p dotFilename.
  void dumpDAG(llvm::StringRef dotFilename);

  /// Dump a dotty graph that depicts the function into the file
  /// \p dotFilename.
  void dumpDAG(const char *dotFilename);

  /// \returns the list of nodes that the Function owns.
  NodesList &getNodes() { return nodes_; }

  const NodesList &getNodes() const { return nodes_; }

  /// \returns a node with the name \p name or nullptr if no node was found.
  Node *getNodeByName(llvm::StringRef name);

  /// \returns a node value identified by \p name, which has the same format as
  /// the one produced by \ref NodeValue::generateNodeOutputName, namely
  /// "nodeName:outputNumber". The returned node value holds a nullptr node if
  /// \p name is not found in the Function or if the node has no outputs (for
  /// example SaveNode). The searched node value can be a graph node, a
  /// constant or a placeholder.
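  ///
  /// For example (a sketch; assumes the Function contains a node named
  /// "relu"):
  /// \code
  ///   NodeValue val = F->getNodeValueByName("relu:0");
  ///   if (val.getNode()) { /* found the first output of "relu" */ }
  /// \endcode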
  NodeValue getNodeValueByName(llvm::StringRef name);

  /// \returns pointer to the class member for the nodes list.
  static NodesList Function::*getNodesMemberPtr() { return &Function::nodes_; }

  /// Randomize all of the Constants in the function. If a Constant with users
  /// in this Function also has users in other Functions then this will result
  /// in a FATAL. \p ignoredConstants is a map from node Kinds to the input
  /// indices for that node kind that should be ignored (not randomized).
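  ///
  /// For example (hypothetical: keep the Constant feeding input index 1 of
  /// every Convolution node untouched while randomizing everything else):
  /// \code
  ///   F->randomizeConstants({{Kinded::Kind::ConvolutionNodeKind, {1}}});
  /// \endcode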
  void randomizeConstants(
      const std::map<Kinded::Kind, std::set<unsigned>> &ignoredConstants = {});
};

struct TrainingConfig;

using VariableGradientsList =
    std::list<std::pair<Placeholder *, Placeholder *>>;

/// Create a new Function that 'trains' the input Function. We differentiate
/// the nodes and insert code to update the weights based on the \p config
/// parameters.
/// If \p varGrads is set then, instead of inserting code to update the
/// weights, the procedure adds code to record the last gradient value: a list
/// of (var, grad_var) pairs associating variables with their gradient
/// variables. This feature is used by the gradient-check unit tests.
/// \returns a new function with the name \p newFuncName.
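///
/// For example (a sketch; assumes the caller has filled in a TrainingConfig,
/// whose fields are not shown here):
/// \code
///   TrainingConfig config;
///   Function *TF = glow::differentiate(F, config, "main_train");
/// \endcode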
Function *differentiate(Function *F, const TrainingConfig &config,
                        llvm::StringRef newFuncName = "",
                        VariableGradientsList *varGrads = nullptr);

/// \returns the first SaveNode user of the placeholder \p PH or
/// nullptr if none are found.
SaveNode *getOutputSave(Function *F, Placeholder *PH);

/// Clone \p node and its sources into \p newF using old-to-new mapping \p
/// currToNew.
Node *recursiveClone(Function *newF, Node *node, NodeMap &currToNew);

/// \returns true if \p PH is an output placeholder in the Function \p F.
/// This is determined by checking if the PH has a user which uses the PH as an
/// overwritten input.
bool isOutput(const Placeholder *PH, const Function &F);

/// \returns true if \p PH is an input placeholder in the Function \p F.
/// This is determined by checking if the PH is the input to a SaveNode or is
/// used by a non-SaveNode.
bool isInput(const Placeholder *PH, const Function &F);

/// Helper vectors for common transpose shuffles (see the usage sketch after
/// these macros).
#define NCHW2NHWC                                                              \
  { 0u, 2u, 3u, 1u }
#define NCTHW2NTHWC                                                            \
  { 0u, 2u, 3u, 4u, 1u }
#define NHWC2NCHW                                                              \
  { 0u, 3u, 1u, 2u }
#define NTHWC2NCTHW                                                            \
  { 0u, 4u, 1u, 2u, 3u }
#define HWCN2NHWC                                                              \
  { 3u, 0u, 1u, 2u }
#define NHWC2HWNC                                                              \
  { 1u, 2u, 0u, 3u }
#define CNHW2NHWC                                                              \
  { 1u, 2u, 3u, 0u }
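
// A usage sketch for the shuffle macros above (hypothetical: assumes an
// NCHW-laid-out NodeValue `in` and a Function *F; it uses the
// Function::createTranspose creator declared earlier in this header):
//
//   auto *toNHWC = F->createTranspose("to_nhwc", in, NCHW2NHWC);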

llvm::raw_ostream &operator<<(llvm::raw_ostream &os, const Module &mod);

llvm::raw_ostream &operator<<(llvm::raw_ostream &os, const Module *mod);

llvm::raw_ostream &operator<<(llvm::raw_ostream &os, const Function &F);

llvm::raw_ostream &operator<<(llvm::raw_ostream &os, const Function *F);

/// \returns whether the Convolution node \p node is equivalent to a
/// FullyConnected node. This happens for a 2D NHWC Convolution with 1x1 filter
/// with strides 1, pads 0, group 1 and dilations 1.
bool isConvolutionSameAsFullyConnected(const ConvolutionNode *node,
                                       bool enforceInput1x1 = false);

/// \returns whether the Gemm node \p node is equivalent to a FullyConnected
/// node. This happens when alpha and beta are 1.0 and the C operand is 1D.
bool isGemmSameAsFullyConnected(const GemmNode *node);

} // namespace glow

#endif // GLOW_GRAPH_GRAPH_H