/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 *
 * \file src/relay/transforms/fuse_ops.cc
 *
 * \brief This is a backend-aware optimization pass.
 *   Fuse necessary ops into a single one.
 */
#include <tvm/relay/analysis.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/op_attr_types.h>
#include <tvm/relay/transform.h>
#include <tvm/tir/op.h>

#include "../../support/arena.h"
#include "pass_util.h"
#include "pattern_util.h"

namespace tvm {
namespace relay {
/*
  Note on the fusion algorithm:

  The main challenge of a general fuser is handling possible diamond-shaped branches.
  In the following graph, conv2d can be fused into the elemwise add.

            conv2d
            /  |  \
           /   |   \
         op    op   op
          \    |    /
           \   |   /
          elemwise add
               |

  However, at the point of conv2d we do not necessarily know that all the future paths
  will merge at the elemwise add. The fusion algorithm applies post-dominator analysis.

  The immediate post-dominator of a node is the closest node through which all future
  paths pass. In the above case, the elemwise add is the post-dominator of conv2d. The
  general algorithm is as follows:

  - Construct a DAG of the dataflow graph for dominator analysis.
  - Construct a post-dominator tree which gives the immediate post-dominator of each node.
  - Run the fusion algorithm with the given post-dominator information.

  Note that, because we run the analysis on a DAG, we use a single-pass post-dominator
  tree construction algorithm via LCA, which is simpler than the full version that
  handles cycles.

  The fusion algorithm traverses from each node and checks whether it can be fused into
  its immediate post-dominator. It has to check the following things:

  - CheckPath: check that every path between a node and its immediate post-dominator
               satisfies the fuse condition.
  - Note that intermediate nodes may already be fused with other nodes; the algorithm
    still runs correctly.
  - CommitFuse: mark all the nodes between source and post-dominator as the same group.
  - We use a Union-Find data structure to manage the groups.
*/
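/*
  For intuition, a minimal standalone sketch (illustrative only, not used by this
  pass) of computing immediate post-dominators of a DAG in reverse topological order
  via pairwise LCA, mirroring DominatorTree::PostDom below. The function name and the
  ipdom/depth arrays are hypothetical.

    #include <cstddef>
    #include <vector>

    // outputs[i] lists the consumers of node i; nodes are indexed in post-DFS
    // order, so every consumer of node i has an index greater than i.
    std::vector<int> PostDominators(const std::vector<std::vector<int>>& outputs) {
      const int n = static_cast<int>(outputs.size());
      std::vector<int> ipdom(n, -1), depth(n, 0);
      // Walk both nodes up toward the root until they meet.
      auto lca = [&](int a, int b) {
        while (a != b) {
          if (a < 0 || b < 0) return -1;
          if (depth[a] < depth[b]) {
            b = ipdom[b];
          } else if (depth[b] < depth[a]) {
            a = ipdom[a];
          } else {
            a = ipdom[a];
            b = ipdom[b];
          }
        }
        return a;
      };
      // Reverse topological order: all consumers are processed before node i.
      for (int i = n - 1; i >= 0; --i) {
        int parent = outputs[i].empty() ? -1 : outputs[i][0];
        for (std::size_t k = 1; k < outputs[i].size(); ++k) {
          parent = lca(parent, outputs[i][k]);
        }
        ipdom[i] = parent;
        depth[i] = parent < 0 ? 1 : depth[parent] + 1;
      }
      return ipdom;
    }

  On the diamond above (conv2d feeding three ops that merge into the add), the LCA of
  the three ops is the add, so ipdom[conv2d] == add, which is exactly the relation
  that licenses the fusion check.
*/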
using support::LinkedList;
using support::LinkNode;

constexpr uint32_t kMaxFusedOps = 256;

static const Op& stop_fusion_op = Op::Get("annotation.stop_fusion");

TVM_REGISTER_PASS_CONFIG_OPTION("relay.FuseOps.max_depth", Integer);

/*!
 * \brief Indexed data flow graph in forward direction.
 *  This is a temporary data structure used for operator fusion analysis.
 *
 *  This data structure only captures the dataflow fragment and
 *  can ignore blocks such as let bindings by simply ordering each dataflow block
 *  and marking the output node as extern_ref.
 */
class IndexedForwardGraph {
 public:
  struct Node;
  /*!
   * The forward edge in the dataflow graph.
   */
  struct Edge {
    /*! \brief The corresponding node */
    Node* node{nullptr};
    /*! \brief The respective pattern of this op */
    OpPatternKind pattern{kOpaque};
  };
  /*! \brief A node in the graph. */
  struct Node {
    /*! \brief weak reference to the corresponding expression object. */
    const tvm::Object* ref{nullptr};
    /*! \brief The index of the node in topological order. */
    size_t index{0};
    /*! \brief Whether this node is referenced by an external source */
    bool extern_ref{false};
    /*! \brief The general pattern of the node */
    OpPatternKind pattern{kOpaque};
    /*! \brief The outputs of the node. */
    LinkedList<Edge> outputs;
  };
  /*! \brief The node map that maps each object to its graph node */
  std::unordered_map<const tvm::Object*, Node*> node_map;
  /*! \brief All the nodes in post DFS order */
  std::vector<Node*> post_dfs_order;

  /*! \brief Dump the graph into a string. */
  void DebugDump() {
    std::ostringstream os;
    for (size_t i = 0; i < post_dfs_order.size(); ++i) {
      Node* node = post_dfs_order[i];
      os << "node[" << i << "], " << GetRef<ObjectRef>(node->ref) << " outputs=[";
      for (auto* link = node->outputs.head; link != nullptr; link = link->next) {
        os << link->value.node->index << ", ";
      }
      os << "]\n";
    }
    LOG(INFO) << os.str();
  }
  /*!
   * \brief Create an indexed forward graph.
   * \param arena The arena used for data allocation.
   * \param body The body of the expression to create a graph.
   */
  static IndexedForwardGraph Create(support::Arena* arena, const Expr& body);

 private:
  class Creator;
};
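// To make the structure concrete: for a body like add(exp(%x), %x) (a hypothetical
// example), Create would produce post_dfs_order [x, exp, add] with forward edges
// x -> exp, x -> add, and exp -> add, each edge labeled with a pattern derived from
// its consumer op, and only the output node add marked extern_ref.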

// Creator of the indexed forward dataflow graph.
class IndexedForwardGraph::Creator : private ExprVisitor {
 public:
  explicit Creator(support::Arena* arena) : arena_(arena) {}

  IndexedForwardGraph Prepare(const Expr& body) {
    this->Update(body, nullptr, kOpaque);
    this->VisitExpr(body);
    return std::move(graph_);
  }

 private:
  /*! \brief allocator of all the internal node objects */
  support::Arena* arena_;
  // The output.
  IndexedForwardGraph graph_;
  // attribute equality comparator
  StructuralEqual attr_equal_;
  // Update the message stored at the node.
  void Update(const Expr& node, IndexedForwardGraph::Node* parent, OpPatternKind pattern) {
    const tvm::Object* key = node.get();
    IndexedForwardGraph::Node* current;
    auto it = graph_.node_map.find(key);
    if (it != graph_.node_map.end()) {
      current = it->second;
    } else {
      current = arena_->make<IndexedForwardGraph::Node>();
      graph_.node_map[key] = current;
    }
    if (parent != nullptr) {
      auto* link = arena_->make<LinkNode<IndexedForwardGraph::Edge> >();
      link->value.node = parent;
      link->value.pattern = pattern;
      current->outputs.Push(link);
    } else {
      current->extern_ref = true;
    }
  }

  void AddNode(const tvm::Object* key) {
    auto it = graph_.node_map.find(key);
    CHECK(it != graph_.node_map.end()) << "Cannot find node " << GetRef<ObjectRef>(key);
    IndexedForwardGraph::Node* node = it->second;
    CHECK(node->ref == nullptr);
    node->ref = key;
    node->index = graph_.post_dfs_order.size();
    graph_.post_dfs_order.push_back(node);
  }

  // Post order tree
  void VisitExpr_(const FunctionNode* op) final {
    // Skip functions that should be handled by external codegen.
    if (op->GetAttr<String>(attr::kCompiler).defined()) return;

    for (auto param : op->params) {
      this->Update(param, nullptr, kOpaque);
    }
    this->Update(op->body, nullptr, kOpaque);
    ExprVisitor::VisitExpr_(op);
  }

  void VisitExpr_(const ConstantNode* op) final {
    this->AddNode(op);
    Node* node = graph_.node_map.at(op);
    DataType dtype = DataType(op->data->dtype);
    // This rule must be consistent with the code generator.
    bool is_simple_const =
        (dtype == DataType::Int(32) || dtype == DataType::Int(64) || dtype == DataType::Float(32) ||
         dtype == DataType::Float(64) || dtype == DataType::Bool());
    if (op->is_scalar() && is_simple_const) {
      node->pattern = kElemWise;
    } else {
      // For now, mark non-scalar constants
      // as opaque; we will not choose to fuse them.
      node->pattern = kOpaque;
    }
  }

  void VisitExpr_(const CallNode* call) final {
    CHECK(graph_.node_map.count(call));
    Node* node = graph_.node_map.at(call);
    static auto fpattern = Op::GetAttrMap<TOpPattern>("TOpPattern");
    // Now we set the pattern of this call.
    //
    // If we see a call mentioning an operator, we should mark it with its
    // annotated pattern.
    //
    // If the pattern is not annotated, we default to opaque.
    //
    // Finally, if the operator position is not a call node, we
    // need to call Update, as it may be an arbitrary expression.
    OpPatternKind op_pattern = kOpaque;
    if (const OpNode* opnode = call->op.as<OpNode>()) {
      auto op = GetRef<Op>(opnode);
      if (IsDynamic(call->checked_type()) && IsDataDependant(call)) {
        // The output of a shape func cannot be fed to a data-dependent shape func.
        op_pattern = kOpaque;
      } else {
        op_pattern = static_cast<OpPatternKind>(fpattern[op]);
      }
    } else {
      this->Update(call->op, node, kOpaque);
    }

    node->pattern = op_pattern;
    this->Update(call->op, nullptr, kOpaque);
    const auto* rtype = call->checked_type().as<TensorTypeNode>();
    // pass the analysis back to all the children it references.
    for (size_t i = 0; i < call->args.size(); ++i) {
      const auto* arg_type = call->args[i]->checked_type().as<TensorTypeNode>();
      // Specifically check whether the result type is the same as the argument type:
      // a broadcast over same-shaped operands degenerates to an elemwise edge.
      OpPatternKind edge_pattern = op_pattern;
      if (edge_pattern == kBroadcast && arg_type != nullptr && rtype != nullptr &&
          attr_equal_(rtype->shape, arg_type->shape)) {
        edge_pattern = kElemWise;
      }
      this->Update(call->args[i], node, edge_pattern);
    }
    ExprVisitor::VisitExpr_(call);
    this->AddNode(call);
  }

  void VisitExpr_(const TupleNode* op) final {
    CHECK(graph_.node_map.count(op));
    Node* tuple_node = graph_.node_map.at(op);
    tuple_node->pattern = kTuple;
    for (const Expr& field : op->fields) {
      if (field->checked_type().as<TensorTypeNode>()) {
        this->Update(field, tuple_node, kInjective);
      } else {
        this->Update(field, nullptr, kOpaque);
      }
    }
    ExprVisitor::VisitExpr_(op);
    this->AddNode(op);
  }

  void VisitExpr_(const TupleGetItemNode* op) final {
    auto tuple_type = op->tuple->checked_type().as<TupleTypeNode>();
    CHECK(tuple_type);
    // When TVM lowers a fused function, it expects all arguments to be a Tensor or
    // a tuple containing only Tensors. But this tuple may contain a reference or
    // another tuple. To avoid modifying codegen logic, we do not allow fusing through this node
    // if the tuple contains such non-Tensor fields. However, all fields will be recursively
    // visited via the call to ExprVisitor::VisitExpr_(op) below and the corresponding
    // visitor methods.
    bool has_non_tensor = false;
    for (auto ty : tuple_type->fields) {
      if (!ty.as<TensorTypeNode>()) {
        has_non_tensor = true;
        break;
      }
    }
    if (has_non_tensor) {
      this->Update(op->tuple, nullptr, kOpaque);
    } else {
      CHECK(graph_.node_map.count(op));
      Node* node = graph_.node_map.at(op);
      node->pattern = kInjective;
      this->Update(op->tuple, node, kInjective);
    }
    ExprVisitor::VisitExpr_(op);
    this->AddNode(op);
  }

  void VisitExpr_(const VarNode* op) final { this->AddNode(op); }

  void VisitExpr_(const LetNode* op) final {
    // do not fuse through let.
    this->Update(op->var, nullptr, kOpaque);
    this->Update(op->value, nullptr, kOpaque);
    this->Update(op->body, nullptr, kOpaque);
    ExprVisitor::VisitExpr_(op);
    this->AddNode(op);
  }

  void VisitExpr_(const IfNode* op) final {
    // do not fuse through if.
    this->Update(op->cond, nullptr, kOpaque);
    this->Update(op->true_branch, nullptr, kOpaque);
    this->Update(op->false_branch, nullptr, kOpaque);
    ExprVisitor::VisitExpr_(op);
    this->AddNode(op);
  }

  void VisitExpr_(const RefCreateNode* op) final {
    this->Update(op->value, nullptr, kOpaque);
    ExprVisitor::VisitExpr_(op);
    this->AddNode(op);
  }

  void VisitExpr_(const RefReadNode* op) final {
    this->Update(op->ref, nullptr, kOpaque);
    ExprVisitor::VisitExpr_(op);
    this->AddNode(op);
  }

  void VisitExpr_(const RefWriteNode* op) final {
    this->Update(op->ref, nullptr, kOpaque);
    this->Update(op->value, nullptr, kOpaque);
    ExprVisitor::VisitExpr_(op);
    this->AddNode(op);
  }

  void VisitExpr_(const MatchNode* op) final {
    this->Update(op->data, nullptr, kOpaque);
    for (const Clause& c : op->clauses) {
      this->Update(c->rhs, nullptr, kOpaque);
    }
    ExprVisitor::VisitExpr_(op);
    this->AddNode(op);
  }
};

IndexedForwardGraph IndexedForwardGraph::Create(support::Arena* arena, const Expr& body) {
  return Creator(arena).Prepare(body);
}

/*!
 * \brief Dominator tree that represents the domination or
 *  post-domination relation of the nodes.
 */
class DominatorTree {
 public:
  /*!
   * \brief A node in the dominator tree.
   */
  struct Node {
    /*! \brief The corresponding node in the dataflow graph */
    IndexedForwardGraph::Node* gnode{nullptr};
    /*! \brief The parent in the tree */
    Node* parent{nullptr};
    /*! \brief The current depth */
    int depth{0};
    /*! \brief The aggregated pattern to the parent */
    OpPatternKind pattern{kOpaque};
  };
  // index -> node.
  std::vector<Node*> nodes;
  /*!
   * \brief Compute the post-dominator relation for a given dataflow graph.
   * \param arena The arena used for node allocation.
   * \param graph The graph to be analyzed.
   * \return The dominator tree of the graph.
   * \note This algorithm makes use of the fact that the graph is a DAG,
   *       and runs a single-pass algorithm via LCA (Least Common Ancestor).
   */
  static DominatorTree PostDom(support::Arena* arena, const IndexedForwardGraph& graph);

 private:
  // Combine patterns together.
  static OpPatternKind CombinePattern(OpPatternKind lhs, OpPatternKind rhs) {
    if (lhs > rhs) return lhs;
    return rhs;
  }
  /*!
   * \brief Find the least common ancestor of the two nodes.
   * \param lhs The left node.
   * \param rhs The right node.
   * \param edge_pattern
   *        The combined edge pattern across all the parents.
   * \return The least common ancestor of the two.
   */
  static Node* LeastCommonAncestor(Node* lhs, Node* rhs, OpPatternKind* edge_pattern) {
    while (lhs != rhs) {
      if (lhs == nullptr) return nullptr;
      if (rhs == nullptr) return nullptr;
      if (lhs->depth < rhs->depth) {
        edge_pattern[0] = CombinePattern(edge_pattern[0], rhs->pattern);
        rhs = rhs->parent;
      } else if (rhs->depth < lhs->depth) {
        edge_pattern[0] = CombinePattern(edge_pattern[0], lhs->pattern);
        lhs = lhs->parent;
      } else {
        edge_pattern[0] = CombinePattern(edge_pattern[0], lhs->pattern);
        edge_pattern[0] = CombinePattern(edge_pattern[0], rhs->pattern);
        lhs = lhs->parent;
        rhs = rhs->parent;
      }
    }
    return lhs;
  }
  /*!
   * \brief Find the least common ancestor of a list of nodes.
   * \param input_nodes The nodes.
   * \param edge_pattern
   *        The combined edge pattern across all the parents.
   * \return The least common ancestor of all nodes.
   */
  Node* LeastCommonAncestor(const LinkedList<IndexedForwardGraph::Edge>& input_nodes,
                            OpPatternKind* edge_pattern) {
    auto link = input_nodes.head;
    if (link == nullptr) {
      return nullptr;
    }
    auto get_node = [&](const IndexedForwardGraph::Edge& edge) {
      size_t oindex = edge.node->index;
      CHECK_LT(oindex, nodes.size());
      Node* onode = nodes[oindex];
      CHECK(onode != nullptr);
      return onode;
    };
    Node* parent = get_node(link->value);
    *edge_pattern = CombinePattern(*edge_pattern, link->value.pattern);
    link = link->next;
    for (; link != nullptr; link = link->next) {
      parent = LeastCommonAncestor(parent, get_node(link->value), edge_pattern);
      *edge_pattern = CombinePattern(*edge_pattern, link->value.pattern);
    }
    return parent;
  }
  /*!
   * \brief Convert an IndexedForwardGraph node into a DominatorTree node.
   * \param arena The Arena.
   * \param gnode An IndexedForwardGraph Node.
   * \return The DominatorTree Node.
   */
  Node* GetNode(support::Arena* arena, IndexedForwardGraph::Node* gnode) {
    Node* tnode = arena->make<Node>();
    tnode->gnode = gnode;
    if (gnode->extern_ref) {
      tnode->depth = 1;
      tnode->parent = nullptr;
      tnode->pattern = kOpaque;
    } else {
      // find the LCA of all outputs.
      OpPatternKind pattern = kElemWise;
      Node* parent = LeastCommonAncestor(gnode->outputs, &pattern);
      tnode->depth = parent ? parent->depth + 1 : 1;
      tnode->parent = parent;
      tnode->pattern = pattern;
    }
    return tnode;
  }
};

DominatorTree DominatorTree::PostDom(support::Arena* arena, const IndexedForwardGraph& graph) {
  DominatorTree tree;
  tree.nodes.resize(graph.post_dfs_order.size(), nullptr);
  // Iterate in reverse topological order, so the tree nodes of all
  // consumers already exist when a node's LCA is computed.
  for (size_t i = graph.post_dfs_order.size(); i != 0; --i) {
    size_t index = i - 1;
    tree.nodes[index] = tree.GetNode(arena, graph.post_dfs_order[index]);
  }
  return tree;
}
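
// A short trace on the diamond from the note at the top of this file: the add is
// the output, so it is extern_ref with depth 1 and no parent; each branch op has
// the add as its only output, hence parent = add and depth 2; conv2d's parent is
// the LCA of its three outputs, which walk up to the add, so conv2d's immediate
// post-dominator is the add, with depth 2.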

/*!
 * \brief Partitioner of the graph: the partition is
 *  tracked by a union-find data structure.
 */
class GraphPartitioner {
 public:
  explicit GraphPartitioner(support::Arena* arena, int opt_level, size_t max_fuse_depth)
      : arena_(arena), opt_level_(opt_level), max_fuse_depth_(max_fuse_depth) {}
  /*!
   * \brief Group as a union find data structure.
   */
  struct Group {
    /*! \brief The parent in the union find data structure. */
    Group* parent{nullptr};
    /*! \brief The pattern of the group */
    OpPatternKind pattern;
    /*! \brief Reference to the root node. */
    const tvm::Object* root_ref{nullptr};
    /*!
     * \brief Reference to the master node;
     *  this field is not nullptr only if pattern is kOutEWiseFusable.
     */
    const tvm::Object* master_ref{nullptr};
    /*!
     * \brief Find the group root, performing path compression.
     * \return The root group node.
     */
    Group* FindRoot() {
      // fast path
      if (this->parent == nullptr) return this;
      // slow path with path compression.
      Group* root = this;
      while (root->parent != nullptr) {
        root = root->parent;
      }
      for (Group* p = this; p != root;) {
        Group* parent = p->parent;
        p->parent = root;
        p = parent;
      }
      return root;
    }

    /*!
     * \brief The number of nodes belonging to this group
     */
    uint32_t num_nodes{1};
  };
  /*!
   * \brief Partition a graph.
   * \return The group assignment of each node.
   */
  std::vector<Group*> Partition(const IndexedForwardGraph& graph);

 private:
  /*! \brief The internal arena for temporary space. */
  support::Arena* arena_;
  /*! \brief The optimization level for the fuse operation. */
  int opt_level_;
  /*! \brief The maximum number of operations in one fused function */
  size_t max_fuse_depth_;
  /*! \brief The internal groups. */
  std::vector<Group*> groups_;
  /*! \brief Internal field used for deduplication */
  std::unordered_set<IndexedForwardGraph::Node*> visited_;
  // Internal implementation of CheckPath.
  template <typename F>
  bool CheckPath_(IndexedForwardGraph::Node* src, IndexedForwardGraph::Node* sink, F fcond) {
    if (visited_.count(src)) return true;
    visited_.insert(src);
    Group* gnode = groups_[src->index];
    CHECK(gnode != nullptr);
    gnode = gnode->FindRoot();
    if (!fcond(gnode->pattern, src == sink)) return false;
    if (src == sink) return true;
    for (auto link = src->outputs.head; link != nullptr; link = link->next) {
      if (!CheckPath_(link->value.node, sink, fcond)) return false;
    }
    return true;
  }
  /*!
   * \brief Check that all the node and edge patterns
   *  between src and sink satisfy fcond.
   *
   * src itself is not checked.
   *
   * \param src The source node.
   * \param sink The termination node.
   * \param fcond The condition to be checked.
   * \tparam F The condition function, with signature bool(OpPatternKind kind, bool is_sink).
   * \note sink must be a post-dominator of src.
   */
  template <typename F>
  bool CheckPath(IndexedForwardGraph::Node* src, IndexedForwardGraph::Node* sink, F fcond) {
    CHECK(!src->extern_ref);
    visited_.clear();
    CHECK(src != sink);
    for (auto link = src->outputs.head; link != nullptr; link = link->next) {
      if (!CheckPath_(link->value.node, sink, fcond)) return false;
    }
    return true;
  }
  // Combine two patterns together.
  static OpPatternKind CombinePattern(OpPatternKind lhs, OpPatternKind rhs) {
    if (lhs > kBroadcast && rhs > kBroadcast) {
      LOG(FATAL) << "Cannot merge two complex groups together";
    }
    if (lhs > rhs) return lhs;
    return rhs;
  }
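  // Taking the max works because OpPatternKind is ordered by increasing generality
  // (kElemWise < kBroadcast < kInjective < kCommReduce < kOutEWiseFusable < kTuple
  // < kOpaque, see include/tvm/relay/op_attr_types.h): the merged group must carry
  // the weakest guarantee among its members, and two patterns above kBroadcast are
  // both "complex", so merging them is rejected above.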
  /*!
   * \brief Merge the child group into the parent.
   * \param child The child group.
   * \param parent The parent group.
   */
  void MergeFromTo(Group* child, Group* parent) {
    child = child->FindRoot();
    parent = parent->FindRoot();
    if (child == parent) return;
    // update the number of nodes of the parent group
    parent->num_nodes += child->num_nodes;
    child->parent = parent;
    // update master ref and pattern
    if (child->master_ref != nullptr) {
      CHECK(parent->master_ref == nullptr);
      parent->master_ref = child->master_ref;
      parent->pattern = CombinePattern(child->pattern, parent->pattern);
    }
  }
  // Internal implementation of CommitFuse.
  void CommitFuse_(IndexedForwardGraph::Node* src, IndexedForwardGraph::Node* sink, Group* target) {
    if (src == sink) return;
    if (visited_.count(src)) return;
    visited_.insert(src);
    Group* gnode = groups_[src->index];
    CHECK(gnode != nullptr);
    // merge the current group into the target if possible.
    MergeFromTo(gnode, target);
    for (auto link = src->outputs.head; link != nullptr; link = link->next) {
      CommitFuse_(link->value.node, sink, target);
    }
  }
  /*!
   * \brief Commit a fusion operation.
   * \param src The source node.
   * \param sink The termination node.
   * \note sink must be a post-dominator of src.
   */
  void CommitFuse(IndexedForwardGraph::Node* src, IndexedForwardGraph::Node* sink) {
    Group* target = groups_[sink->index];
    visited_.clear();
    CHECK(src != sink);
    CommitFuse_(src, sink, target);
  }

  size_t CountNodesUptoSink_(IndexedForwardGraph::Node* src, IndexedForwardGraph::Node* sink) {
    if (src == sink || visited_.count(src)) return 0;
    visited_.insert(src);
    Group* gnode = groups_[src->index];
    CHECK(gnode != nullptr);
    auto sum = gnode->num_nodes;
    for (auto link = src->outputs.head; link != nullptr; link = link->next) {
      sum += CountNodesUptoSink_(link->value.node, sink);
    }
    return sum;
  }

  // Count the number of nodes in a fused subgraph if child is additionally fused.
  // dom_parent is already known to be a part of the subgraph.
  // For a diamond structure, there can be multiple paths connecting child and dom_parent.
  // All intermediate nodes between child and dom_parent are taken into account.
  // Since dom_parent can itself be an intermediate node in the subgraph, calling FindRoot()
  // is important for a correct calculation.
  size_t CountFusedNodesWithNewChild(IndexedForwardGraph::Node* child,
                                     IndexedForwardGraph::Node* dom_parent) {
    Group* target = groups_[dom_parent->index];
    visited_.clear();
    CHECK(child != dom_parent);
    return target->FindRoot()->num_nodes + CountNodesUptoSink_(child, dom_parent);
  }

  // Initialize the groups.
  void InitGroups(const IndexedForwardGraph& graph) {
    groups_.resize(graph.post_dfs_order.size());
    for (size_t nid = 0; nid < groups_.size(); ++nid) {
      const auto* graph_node = graph.post_dfs_order[nid];
      auto* group_node = arena_->make<Group>();
      group_node->pattern = graph_node->pattern;
      group_node->root_ref = graph_node->ref;
      // set master ref if necessary.
      if (group_node->pattern == kOutEWiseFusable) {
        group_node->master_ref = graph_node->ref;
      }
      groups_[nid] = group_node;
    }
  }

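  // Partition() below invokes RunFuse three times with increasing `phase`:
  //   phase 0: fuse kOutEWiseFusable ops (e.g. conv2d) with their elemwise
  //            post-dominators, and fuse elemwise/broadcast chains;
  //   phase 1: fuse the remaining injective ops;
  //   phase 2: fuse injective ops into intermediate tuples, if any.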
  // Execute the fusion algorithm.
  void RunFuse(const IndexedForwardGraph& graph, const DominatorTree& post_dom_tree, int phase) {
    for (size_t nid = 0; nid < groups_.size(); ++nid) {
      // The group of the current node has been specified already.
      auto* graph_node = graph.post_dfs_order[nid];
      auto* dom_node = post_dom_tree.nodes[nid];
      Group* group_node = groups_[nid];
      CHECK(group_node != nullptr);
      // no actions for opaque nodes
      if (group_node->pattern == kOpaque) continue;
      // no actions needed if the current node has no dominator
      if (dom_node->parent == nullptr) continue;
      CHECK(!graph_node->extern_ref);
      size_t dom_parent_gindex = dom_node->parent->gnode->index;

      // Refuse the fusion if too many ops would be fused together.
      if (CountFusedNodesWithNewChild(graph_node, dom_node->parent->gnode) > max_fuse_depth_)
        continue;

      if (phase == 2) {
        // Fuse injective ops into intermediate tuples, if any.
        if (group_node->pattern > kInjective) continue;
        Group* dom_parent_group = groups_[dom_parent_gindex];
        Group* dom_root_group = dom_parent_group->FindRoot();
        // If the dom node group has a tuple as its root, we do not fuse tuple fields into it.
        if (dom_root_group->pattern == kTuple) continue;
        if (dom_parent_group->pattern == kTuple && dom_root_group->pattern <= kInjective) {
          // Now we know the tuple has been fused into subsequent injective ops.
          auto fcond = [](OpPatternKind kind, bool is_sink) { return kind <= kInjective; };
          // dom_root_group can also be a tuple, as in inception layers;
          // CheckPath is needed to avoid fusing two intermediate tuples.
          if (CheckPath(graph_node, dom_node->parent->gnode, fcond)) {
            CommitFuse(graph_node, dom_node->parent->gnode);
          }
        }
        continue;
      }

      // Skip if the current node is already fused to the parent.
      if (groups_[dom_parent_gindex] != nullptr &&
          group_node->FindRoot() == groups_[dom_parent_gindex]->FindRoot()) {
        continue;
      }
      // Do not fuse into tuples for now.
      if (groups_[dom_parent_gindex]->pattern == kTuple) continue;
      // Try to fuse the current node to its post-dominator.
      if (group_node->pattern == kOutEWiseFusable) {
        if (phase != 0) continue;
        // Path for OutEWiseFusable: conv2d.
        // Check if the dominator relation is elemwise.
        if (dom_node->parent != nullptr && dom_node->pattern == kElemWise) {
          CHECK(dom_node->parent->gnode != nullptr);
          // The fusion can be executed if all the intermediate ops are still broadcast.
          auto fcond = [](OpPatternKind kind, bool is_sink) { return kind <= kBroadcast; };
          if (CheckPath(graph_node, dom_node->parent->gnode, fcond)) {
            CommitFuse(graph_node, dom_node->parent->gnode);
          }
        }
      } else if (group_node->pattern <= kBroadcast) {
        // Pre-condition: it can only be fused to a parent which is injective or reduction.
        if (dom_node->parent != nullptr &&
            (dom_node->pattern <= kInjective || dom_node->pattern == kCommReduce)) {
          // Check if all the intermediate ops are still broadcast.
          // The final terminal node can already be fused to an OutEWiseFusable group.
          auto fcond = [](OpPatternKind kind, bool is_sink) {
            if (!is_sink) {
              // Elemwise, broadcast, and injective ops on the parallel branches
              // are allowed to be fused to the elemwise/broadcast master.
              return kind <= kInjective;
            } else {
              return (kind <= kBroadcast || kind == kCommReduce || kind == kInjective ||
                      kind == kOutEWiseFusable);
            }
          };
          if (CheckPath(graph_node, dom_node->parent->gnode, fcond)) {
            CommitFuse(graph_node, dom_node->parent->gnode);
          }
        }
      } else if (group_node->pattern == kInjective || group_node->pattern == kTuple) {
        // Defer injective fusion to the second phase,
        // so conv2d always finishes fusing first.
        if (phase != 1) continue;
        // Check if all paths are injective.
        auto fcond = [](OpPatternKind kind, bool is_sink) { return kind <= kInjective; };
        if (CheckPath(graph_node, dom_node->parent->gnode, fcond)) {
          CommitFuse(graph_node, dom_node->parent->gnode);
        }
      } else {
        // do nothing.
        CHECK(group_node->pattern == kCommReduce);
      }
    }
  }
};

std::vector<GraphPartitioner::Group*> GraphPartitioner::Partition(
    const IndexedForwardGraph& graph) {
  this->InitGroups(graph);
  if (opt_level_ == 0) return std::move(groups_);
  // get the post-dominator tree
  auto post_dom_tree = DominatorTree::PostDom(arena_, graph);
  // run the fusion algorithm.
  for (int phase = 0; phase < 3; ++phase) {
    this->RunFuse(graph, post_dom_tree, phase);
  }
  return std::move(groups_);
}

class FuseMutator : private ExprMutator {
 public:
  // Run the transform.
  Expr Transform(const Expr& body, int fuse_opt_level, size_t max_fuse_depth) {
    // setup the group map.
    auto graph = IndexedForwardGraph::Create(&arena_, body);
    auto groups = GraphPartitioner(&arena_, fuse_opt_level, max_fuse_depth).Partition(graph);
    for (size_t nid = 0; nid < graph.post_dfs_order.size(); ++nid) {
      CHECK(graph.post_dfs_order[nid]->ref != nullptr);
      gmap_[graph.post_dfs_order[nid]->ref] = groups[nid];
    }
    // The following line can be used for debugging.
    // this->DebugDumpGroup(body);
    return this->Mutate(body);
  }

 private:
  /*! \brief Temporary information from each group. */
  struct GroupInfo {
   public:
    // The parameters of the function.
    Array<Var> params;
    // The arguments to call the function.
    Array<Expr> arguments;
    // Get a new parameter or allocate an old one.
    Var GetOrAllocParam(const Expr& expr, const Type& type) {
      // Run a linear scan, as most fused groups contain only a few inputs.
      for (size_t i = 0; i < arguments.size(); ++i) {
        if (expr.same_as(arguments[i])) return params[i];
      }
      // create a new parameter.
      std::ostringstream os;
      os << "p" << params.size();
      auto var = Var(os.str(), type);
      params.push_back(var);
      arguments.push_back(expr);
      return var;
    }
  };
  /*! \brief Internal arena. */
  support::Arena arena_;
  /*! \brief The group assignment map. */
  std::unordered_map<const Object*, GraphPartitioner::Group*> gmap_;
  /*! \brief Internal group information map. */
  std::unordered_map<GraphPartitioner::Group*, GroupInfo> ginfo_;

  // Skip primitive functions.
  Expr VisitExpr_(const FunctionNode* fn_node) {
    if (fn_node->HasNonzeroAttr(attr::kPrimitive)) {
      return GetRef<Expr>(fn_node);
    } else {
      return ExprMutator::VisitExpr_(fn_node);
    }
  }

  // Transform calls.
  Expr VisitExpr_(const CallNode* call) {
    if (call->op.as<OpNode>()) {
      static auto fnoncomputational = Op::GetAttrMap<TNonComputational>("TNonComputational");

      if (fnoncomputational.get(Downcast<Op>(call->op), false)) {
        return ExprMutator::VisitExpr_(call);
      }

      // If it is a primitive op call,
      // then we must already have a group assignment for it.
      CHECK(gmap_.count(call));
      if (call->op == stop_fusion_op) {
        return ExprMutator::VisitExpr(call->args[0]);
      }
      auto* ret_group = gmap_.at(call)->FindRoot();
      Array<Expr> new_args = GetNewArguments(call->args, ret_group);

      auto new_call = Call(call->op, new_args, call->attrs, call->type_args);

      if (ret_group->root_ref == call) {
        // This is the root of the group;
        // create the new call node.
        return MakeNewFunction(ret_group, call->checked_type(), new_call);
      } else {
        // This is an intermediate node of a fused function;
        // simply return the new call.
        return std::move(new_call);
      }
    } else {
      return ExprMutator::VisitExpr_(call);
    }
  }

  Expr VisitExpr_(const TupleNode* tuple) {
    auto* ret_group = gmap_.at(tuple)->FindRoot();
    if (ret_group->root_ref == tuple) {
      return ExprMutator::VisitExpr_(tuple);
    }
    // This tuple is an intermediate node in the group.
    Array<Expr> new_fields = GetNewArguments(tuple->fields, ret_group);
    return Tuple(new_fields);
  }

  Expr VisitExpr_(const TupleGetItemNode* tuple_get) {
    auto* ret_group = gmap_.at(tuple_get)->FindRoot();
    auto new_tuple = GetNewArguments({tuple_get->tuple}, ret_group)[0];
    auto new_node = TupleGetItem(new_tuple, tuple_get->index);
    if (ret_group->root_ref == tuple_get) {
      if (gmap_.at(tuple_get->tuple.get())->FindRoot() != ret_group) {
        // Isolated. This case occurs when the tuple is created by an Opaque op,
        // e.g. multibox_transform_loc.
        return ExprMutator::VisitExpr_(tuple_get);
      }
      // A new function whose output is a tuple field access.
      return MakeNewFunction(ret_group, tuple_get->checked_type(), new_node);
    }
    // This is an intermediate node in the group.
    return std::move(new_node);
  }

  Expr MakeNewFunction(GraphPartitioner::Group* group, Type ret_type, Expr body) {
    // If the function has no call, it is not a primitive function.
    struct HasCallVisitor : ExprVisitor {
      bool has_call = false;
      void VisitExpr_(const CallNode* op) final { has_call = true; }
    } visitor;
    visitor(body);
    const GroupInfo& ginfo = ginfo_[group];
    auto func = Function(ginfo.params, body, ret_type, {});
    func = WithAttr(std::move(func), attr::kPrimitive, tvm::Integer(visitor.has_call));
    return Call(func, ginfo.arguments, Attrs());
  }
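  // For illustration: if a group's root computes add(exp(%x), %x), the rewrite
  // produces a call of a fresh primitive function (the names are hypothetical):
  //
  //   %f = fn (%p0) { add(exp(%p0), %p0) }  // carries attr Primitive=1
  //   %f(%x)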

  Array<Expr> GetNewArguments(const tvm::Array<Expr>& args,
                              GraphPartitioner::Group* current_group) {
    Array<Expr> new_args;
    for (auto arg : args) {
      auto* arg_group = gmap_.at(arg.get())->FindRoot();
      auto type = arg->checked_type();
      Expr new_arg = this->Mutate(arg);
      if (current_group != arg_group) {
        Var param = ginfo_[current_group].GetOrAllocParam(new_arg, type);
        new_args.push_back(param);
      } else {
        new_args.push_back(new_arg);
      }
    }
    return new_args;
  }

  // Debug function, dump the group assignment in text.
  void DebugDumpGroup(const Expr& body) {
    std::string text = AsText(body, false, [this](const ObjectRef& expr) -> std::string {
      auto it = gmap_.find(expr.get());
      if (it == gmap_.end()) return "";
      std::ostringstream os;
      auto* group = it->second->FindRoot();
      os << " /* group=" << group << " */";
      return os.str();
    });
    LOG(INFO) << "Dump of group info:\n" << text;
  }
};

Expr FuseOps(const Expr& expr, int fuse_opt_level, size_t max_fuse_depth, const IRModule& module) {
  return FuseMutator().Transform(expr, fuse_opt_level, max_fuse_depth);
}

namespace transform {

Pass FuseOps(int fuse_opt_level) {
  runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
      [=](Function f, IRModule m, PassContext pc) {
        int opt_level = fuse_opt_level == -1 ? pc->opt_level : fuse_opt_level;
        auto max_fuse_depth = pc->GetConfig("relay.FuseOps.max_depth", Integer(kMaxFusedOps));
        return Downcast<Function>(FuseOps(f, opt_level, max_fuse_depth.value(), m));
      };
  return CreateFunctionPass(pass_func, 1, "FuseOps", {"InferType"});
}

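// A minimal usage sketch, assuming an already constructed IRModule `mod` (the pass
// infrastructure runs the declared InferType dependency automatically):
//
//   auto pass = relay::transform::FuseOps(-1);  // -1: follow PassContext::opt_level
//   auto ctx = transform::PassContext::Create();
//   ctx->opt_level = 3;
//   {
//     With<transform::PassContext> scope(ctx);
//     mod = pass(mod);
//   }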
TVM_REGISTER_GLOBAL("relay._transform.FuseOps").set_body_typed(FuseOps);

}  // namespace transform

}  // namespace relay
}  // namespace tvm