1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20 /*!
21 *
22 * \file src/relay/transforms/fuse_ops.cc
23 *
24 * \brief This is a backend-aware optimization pass.
25 * Fuse necessary ops into a single one.
26 */
27 #include <tvm/relay/analysis.h>
28 #include <tvm/relay/expr_functor.h>
29 #include <tvm/relay/op_attr_types.h>
30 #include <tvm/relay/transform.h>
31 #include <tvm/tir/op.h>
32
33 #include "../../support/arena.h"
34 #include "pass_util.h"
35 #include "pattern_util.h"
36
37 namespace tvm {
38 namespace relay {
39
40 /*
41 Note on Fusing algorithm:
42
43 The main challenge of general fusor is to handle possible diamond shape branches,
44 in the following graph, conv2d can be fused to elemwise add.
45
46 conv2d
47 / | \
48 / | \
49 op op op
50 \ | /
51 \ | /
52 elemwise add
53 |
54
55 However, at the point of conv2d we do not necessarily know that all the future paths
56 will merge at the elemwise add. The fusion algorithm applies post-dominator analysis.
57
The immediate post-dominator of a node is defined as the closest node that all future paths
from it go through. In the above case, the elemwise add is the post-dominator of conv2d.
The general algorithm is as follows:
61
62 - Construct a DAG of dataflow graph for dominator analysis
63 - Construct a post-dominator tree which gives immediate post dominator of each node.
64 - Run fusion algorithm with the given post-dominator information.
65
66 Note that, because we run analysis on a DAG, we use a single pass post-dominator
67 tree construction algorithm via LCA, which is simpler than the full version that handles cycles.
68
69 The fusion algorithm traverses from each node and checks if it can be fused to its
70 immediate post dominator. It has to check the following things:
71
- CheckPath: check that all the paths between a node and its immediate post-dominator
  satisfy the fuse condition.
  - Note that these intermediate nodes can already be fused with other nodes; the algorithm
    will still run correctly.
76 - CommitFuse: mark all the nodes between source and post-dominator as the same group.
77 - We use an Union-Find data structure to manage the groups.
78 */
// Linked-list helpers from the support namespace; the graph below stores its edges in them.
using support::LinkedList;
using support::LinkNode;

// NOTE(review): presumably the default cap on the number of ops in one fused function;
// its use is outside this section of the file -- confirm.
constexpr uint32_t kMaxFusedOps = 256;

// Handle to the "annotation.stop_fusion" operator, looked up once during static initialization.
// FuseMutator compares call->op against it when rewriting calls.
static const Op& stop_fusion_op = Op::Get("annotation.stop_fusion");

// Registers the Integer-typed pass config option "relay.FuseOps.max_depth".
TVM_REGISTER_PASS_CONFIG_OPTION("relay.FuseOps.max_depth", Integer);
87
88 /*!
89 * \brief Indexed data flow graph in forward direction.
90 * This is a temporary data structure used for operator fusion analysis.
91 *
92 * This data structure only captures the dataflow fragment and
93 * could ignore blocks like let by simply ordering each dataflow block
94 * and mark the output node as extern_ref;
95 */
96 class IndexedForwardGraph {
97 public:
98 struct Node;
99 /*!
100 * The forward edge in the dataflow graph.
101 */
102 struct Edge {
103 /*! \brief The corresponding node */
104 Node* node{nullptr};
105 /*! \brief The respective pattern of this op */
106 OpPatternKind pattern{kOpaque};
107 };
108 /*! \brief A node in the graph. */
109 struct Node {
110 /*! \brief weak reference to the corresponding edge. */
111 const tvm::Object* ref{nullptr};
112 /*! \brief The index of the node in topological order. */
113 size_t index{0};
114 /*! \brief Whether this node is referenced by external source */
115 bool extern_ref{false};
116 /*! \brief The general pattern in the node */
117 OpPatternKind pattern{kOpaque};
118 /*! \brief The outputs of the node. */
119 LinkedList<Edge> outputs;
120 };
121 /*! \brief The node map that maps node to graph */
122 std::unordered_map<const tvm::Object*, Node*> node_map;
123 /*! \brief All the nodes in post DFS order */
124 std::vector<Node*> post_dfs_order;
125
126 /*! \brief Dump the graph into string. */
DebugDump()127 void DebugDump() {
128 std::ostringstream os;
129 for (size_t i = 0; i < post_dfs_order.size(); ++i) {
130 Node* node = post_dfs_order[i];
131 os << "node[" << i << "], " << GetRef<ObjectRef>(node->ref) << " outputs=[";
132 for (auto* link = node->outputs.head; link != nullptr; link = link->next) {
133 os << link->value.node->index << ", ";
134 }
135 os << "]\n";
136 }
137 LOG(INFO) << os.str();
138 }
139 /*!
140 * \brief create a indexed forward graph.
141 * \param arena The arena used for data allocation.
142 * \param body The body of the expression to create a graph.
143 */
144 static IndexedForwardGraph Create(support::Arena* arena, const Expr& body);
145
146 private:
147 class Creator;
148 };
149
150 // Creator of post dominator tree of the dataflow
151 class IndexedForwardGraph::Creator : private ExprVisitor {
152 public:
Creator(support::Arena * arena)153 explicit Creator(support::Arena* arena) : arena_(arena) {}
154
Prepare(const Expr & body)155 IndexedForwardGraph Prepare(const Expr& body) {
156 this->Update(body, nullptr, kOpaque);
157 this->VisitExpr(body);
158 return std::move(graph_);
159 }
160
161 private:
162 /*! \brief allocator of all the internal node object */
163 support::Arena* arena_;
164 // The output.
165 IndexedForwardGraph graph_;
166 // attribute equal comparator
167 StructuralEqual attr_equal_;
168 // Update the message stored at the node.
Update(const Expr & node,IndexedForwardGraph::Node * parent,OpPatternKind pattern)169 void Update(const Expr& node, IndexedForwardGraph::Node* parent, OpPatternKind pattern) {
170 const tvm::Object* key = node.get();
171 IndexedForwardGraph::Node* current;
172 auto it = graph_.node_map.find(key);
173 if (it != graph_.node_map.end()) {
174 current = it->second;
175 } else {
176 current = arena_->make<IndexedForwardGraph::Node>();
177 graph_.node_map[key] = current;
178 }
179 if (parent != nullptr) {
180 auto* link = arena_->make<LinkNode<IndexedForwardGraph::Edge> >();
181 link->value.node = parent;
182 link->value.pattern = pattern;
183 current->outputs.Push(link);
184 } else {
185 current->extern_ref = true;
186 }
187 }
188
AddNode(const tvm::Object * key)189 void AddNode(const tvm::Object* key) {
190 auto it = graph_.node_map.find(key);
191 CHECK(it != graph_.node_map.end()) << "Cannot find node " << GetRef<ObjectRef>(key);
192 IndexedForwardGraph::Node* node = it->second;
193 CHECK(node->ref == nullptr);
194 node->ref = key;
195 node->index = graph_.post_dfs_order.size();
196 graph_.post_dfs_order.push_back(node);
197 }
198
199 // Post order tree
VisitExpr_(const FunctionNode * op)200 void VisitExpr_(const FunctionNode* op) final {
201 // Skip the function that should be handled by external codegen.
202 if (op->GetAttr<String>(attr::kCompiler).defined()) return;
203
204 for (auto param : op->params) {
205 this->Update(param, nullptr, kOpaque);
206 }
207 this->Update(op->body, nullptr, kOpaque);
208 ExprVisitor::VisitExpr_(op);
209 }
210
VisitExpr_(const ConstantNode * op)211 void VisitExpr_(const ConstantNode* op) final {
212 this->AddNode(op);
213 Node* node = graph_.node_map.at(op);
214 DataType dtype = DataType(op->data->dtype);
215 // This rule must be consistent with code generator.
216 bool is_simple_const =
217 (dtype == DataType::Int(32) || dtype == DataType::Int(64) || dtype == DataType::Float(32) ||
218 dtype == DataType::Float(64) || dtype == DataType::Bool());
219 if (op->is_scalar() && is_simple_const) {
220 node->pattern = kElemWise;
221 } else {
222 // for now, mark non-scalar constant
223 // as opaque, we will not choose to fuse it.
224 node->pattern = kOpaque;
225 }
226 }
227
VisitExpr_(const CallNode * call)228 void VisitExpr_(const CallNode* call) final {
229 CHECK(graph_.node_map.count(call));
230 Node* node = graph_.node_map.at(call);
231 static auto fpattern = Op::GetAttrMap<TOpPattern>("TOpPattern");
232 // Now we set the pattern of this call.
233 //
234 // If we see a call mentioning an operator we should mark it with its
235 // annotated pattern.
236 //
237 // If the pattern is not annotated we will default to opaque.
238 //
239 // Finally if the operator position is not a call node we will
240 // need to call Update, as it may be an arbitrary expression.
241 OpPatternKind op_pattern = kOpaque;
242 if (const OpNode* opnode = call->op.as<OpNode>()) {
243 auto op = GetRef<Op>(opnode);
244 if (IsDynamic(call->checked_type()) && IsDataDependant(call)) {
245 // output of a shape func can't be fed to a data-dependent shape func
246 op_pattern = kOpaque;
247 } else {
248 op_pattern = static_cast<OpPatternKind>(fpattern[op]);
249 }
250 } else {
251 this->Update(call->op, node, kOpaque);
252 }
253
254 node->pattern = op_pattern;
255 this->Update(call->op, nullptr, kOpaque);
256 const auto* rtype = call->checked_type().as<TensorTypeNode>();
257 // pass the analysis back to all the children it references.
258 for (size_t i = 0; i < call->args.size(); ++i) {
259 const auto* arg_type = call->args[i]->checked_type().as<TensorTypeNode>();
260 // specifically check if result type is the same as arguments type
261 OpPatternKind edge_pattern = op_pattern;
262 if (edge_pattern == kBroadcast && arg_type != nullptr && rtype != nullptr &&
263 attr_equal_(rtype->shape, arg_type->shape)) {
264 edge_pattern = kElemWise;
265 }
266 this->Update(call->args[i], node, edge_pattern);
267 }
268 ExprVisitor::VisitExpr_(call);
269 this->AddNode(call);
270 }
271
VisitExpr_(const TupleNode * op)272 void VisitExpr_(const TupleNode* op) final {
273 CHECK(graph_.node_map.count(op));
274 Node* tuple_node = graph_.node_map.at(op);
275 tuple_node->pattern = kTuple;
276 for (const Expr& field : op->fields) {
277 if (field->checked_type().as<TensorTypeNode>()) {
278 this->Update(field, tuple_node, kInjective);
279 } else {
280 this->Update(field, nullptr, kOpaque);
281 }
282 }
283 ExprVisitor::VisitExpr_(op);
284 this->AddNode(op);
285 }
286
VisitExpr_(const TupleGetItemNode * op)287 void VisitExpr_(const TupleGetItemNode* op) final {
288 auto tuple_type = op->tuple->checked_type().as<TupleTypeNode>();
289 CHECK(tuple_type);
290 // When TVM lowers a fused function, it expects all arguments to be a Tensor or
291 // a tuple containing only Tensors. But this tuple may contain a reference or
292 // another tuple. To avoid modifying codegen logic, we do not allow fusing through this node
293 // if the tuple contains such non Tensor fields. However, all fields will be recursively
294 // visited via call to ExprVisitor::VisitExpr_(op) below and corresponding visitor methods.
295 bool has_non_tensor = false;
296 for (auto ty : tuple_type->fields) {
297 if (!ty.as<TensorTypeNode>()) {
298 has_non_tensor = true;
299 break;
300 }
301 }
302 if (has_non_tensor) {
303 this->Update(op->tuple, nullptr, kOpaque);
304 } else {
305 CHECK(graph_.node_map.count(op));
306 Node* node = graph_.node_map.at(op);
307 node->pattern = kInjective;
308 this->Update(op->tuple, node, kInjective);
309 }
310 ExprVisitor::VisitExpr_(op);
311 this->AddNode(op);
312 }
313
VisitExpr_(const VarNode * op)314 void VisitExpr_(const VarNode* op) final { this->AddNode(op); }
315
VisitExpr_(const LetNode * op)316 void VisitExpr_(const LetNode* op) final {
317 // do not fuse through let.
318 this->Update(op->var, nullptr, kOpaque);
319 this->Update(op->value, nullptr, kOpaque);
320 this->Update(op->body, nullptr, kOpaque);
321 ExprVisitor::VisitExpr_(op);
322 this->AddNode(op);
323 }
324
VisitExpr_(const IfNode * op)325 void VisitExpr_(const IfNode* op) final {
326 // do not fuse through if.
327 this->Update(op->cond, nullptr, kOpaque);
328 this->Update(op->true_branch, nullptr, kOpaque);
329 this->Update(op->false_branch, nullptr, kOpaque);
330 ExprVisitor::VisitExpr_(op);
331 this->AddNode(op);
332 }
333
VisitExpr_(const RefCreateNode * op)334 void VisitExpr_(const RefCreateNode* op) final {
335 this->Update(op->value, nullptr, kOpaque);
336 ExprVisitor::VisitExpr_(op);
337 this->AddNode(op);
338 }
339
VisitExpr_(const RefReadNode * op)340 void VisitExpr_(const RefReadNode* op) final {
341 this->Update(op->ref, nullptr, kOpaque);
342 ExprVisitor::VisitExpr_(op);
343 this->AddNode(op);
344 }
345
VisitExpr_(const RefWriteNode * op)346 void VisitExpr_(const RefWriteNode* op) final {
347 this->Update(op->ref, nullptr, kOpaque);
348 this->Update(op->value, nullptr, kOpaque);
349 ExprVisitor::VisitExpr_(op);
350 this->AddNode(op);
351 }
352
VisitExpr_(const MatchNode * op)353 void VisitExpr_(const MatchNode* op) final {
354 this->Update(op->data, nullptr, kOpaque);
355 for (const Clause& c : op->clauses) {
356 this->Update(c->rhs, nullptr, kOpaque);
357 }
358 ExprVisitor::VisitExpr_(op);
359 this->AddNode(op);
360 }
361 };
362
Create(support::Arena * arena,const Expr & body)363 IndexedForwardGraph IndexedForwardGraph::Create(support::Arena* arena, const Expr& body) {
364 return Creator(arena).Prepare(body);
365 }
366
367 /*!
368 * \brief Dominator tree that represent domination or
369 * post domination relation of the node.
370 */
371 class DominatorTree {
372 public:
373 /*!
374 * \brief A node in the dominator tree.
375 */
376 struct Node {
377 /*! \brief The node in the tree */
378 IndexedForwardGraph::Node* gnode{nullptr};
379 /*! \brief parent of the tree */
380 Node* parent{nullptr};
381 /*! \brief current depth*/
382 int depth{0};
383 /*! \brief aggregated pattern to parent */
384 OpPatternKind pattern{kOpaque};
385 };
386 // index -> node.
387 std::vector<Node*> nodes;
388 /*!
389 * \brief compute a post dominator relation for a given dataflow graph.
390 * \param arena The arena used for node allocation.
391 * \param graph The graph to be analyzed.
392 * \return The dominator tree of the graph.
393 * \note This algorithm makes use of the fact that graph is DAG,
394 * and runs a single pass algorithm via LCA (Least Common Ancestor)
395 */
396 static DominatorTree PostDom(support::Arena* arena, const IndexedForwardGraph& graph);
397
398 private:
399 // Combine pattern together.
CombinePattern(OpPatternKind lhs,OpPatternKind rhs)400 static OpPatternKind CombinePattern(OpPatternKind lhs, OpPatternKind rhs) {
401 if (lhs > rhs) return lhs;
402 return rhs;
403 }
404 /*!
405 * \brief Find the least common ancestor of the two nodes.
406 * \param lhs The left node.
407 * \param rhs The right node.
408 * \param edge_pattern
409 * The combined edge pattern across all the parents.
410 * \return The least common ancestor of the two.
411 */
LeastCommonAncestor(Node * lhs,Node * rhs,OpPatternKind * edge_pattern)412 static Node* LeastCommonAncestor(Node* lhs, Node* rhs, OpPatternKind* edge_pattern) {
413 while (lhs != rhs) {
414 if (lhs == nullptr) return nullptr;
415 if (rhs == nullptr) return nullptr;
416 if (lhs->depth < rhs->depth) {
417 edge_pattern[0] = CombinePattern(edge_pattern[0], rhs->pattern);
418 rhs = rhs->parent;
419 } else if (rhs->depth < lhs->depth) {
420 edge_pattern[0] = CombinePattern(edge_pattern[0], lhs->pattern);
421 lhs = lhs->parent;
422 } else {
423 edge_pattern[0] = CombinePattern(edge_pattern[0], lhs->pattern);
424 edge_pattern[0] = CombinePattern(edge_pattern[0], rhs->pattern);
425 lhs = lhs->parent;
426 rhs = rhs->parent;
427 }
428 }
429 return lhs;
430 }
431 /*!
432 * \brief Find the least common ancestor of a list of nodes.
433 * \param nodes the nodes.
434 * \param edge_pattern
435 * The combined edge pattern across all the parents.
436 * \return The least common ancestor of all nodes.
437 */
LeastCommonAncestor(const LinkedList<IndexedForwardGraph::Edge> & input_nodes,OpPatternKind * edge_pattern)438 Node* LeastCommonAncestor(const LinkedList<IndexedForwardGraph::Edge>& input_nodes,
439 OpPatternKind* edge_pattern) {
440 auto link = input_nodes.head;
441 if (link == nullptr) {
442 return nullptr;
443 }
444 auto get_node = [&](const IndexedForwardGraph::Edge& edge) {
445 size_t oindex = edge.node->index;
446 CHECK_LT(oindex, nodes.size());
447 Node* onode = nodes[oindex];
448 CHECK(onode != nullptr);
449 return onode;
450 };
451 Node* parent = get_node(link->value);
452 *edge_pattern = CombinePattern(*edge_pattern, link->value.pattern);
453 link = link->next;
454 for (; link != nullptr; link = link->next) {
455 parent = LeastCommonAncestor(parent, get_node(link->value), edge_pattern);
456 *edge_pattern = CombinePattern(*edge_pattern, link->value.pattern);
457 }
458 return parent;
459 }
460 /*!
461 * \brief Convert the Node from an IndexedForwardGraph Node into DomaintorTree Node.
462 * \param arena The Arena.
463 * \param gnode An IndexedForwardGraph Node.
464 * \return The DominatorTree Node.
465 */
GetNode(support::Arena * arena,IndexedForwardGraph::Node * gnode)466 Node* GetNode(support::Arena* arena, IndexedForwardGraph::Node* gnode) {
467 Node* tnode = arena->make<Node>();
468 tnode->gnode = gnode;
469 if (gnode->extern_ref) {
470 tnode->depth = 1;
471 tnode->parent = nullptr;
472 tnode->pattern = kOpaque;
473 } else {
474 // find the LCAs of all outputs.
475 OpPatternKind pattern = kElemWise;
476 Node* parent = LeastCommonAncestor(gnode->outputs, &pattern);
477 tnode->depth = parent ? parent->depth + 1 : 1;
478 tnode->parent = parent;
479 tnode->pattern = pattern;
480 }
481 return tnode;
482 }
483 };
484
PostDom(support::Arena * arena,const IndexedForwardGraph & graph)485 DominatorTree DominatorTree::PostDom(support::Arena* arena, const IndexedForwardGraph& graph) {
486 DominatorTree tree;
487 tree.nodes.resize(graph.post_dfs_order.size(), nullptr);
488 // reverse topo order
489 for (size_t i = graph.post_dfs_order.size(); i != 0; --i) {
490 size_t index = i - 1;
491 tree.nodes[index] = tree.GetNode(arena, graph.post_dfs_order[index]);
492 }
493 return tree;
494 }
495
496 /*!
497 * \brief A partition of the graph marked by union find data structure.
498 */
499 class GraphPartitioner {
500 public:
GraphPartitioner(support::Arena * arena,int opt_level,size_t max_fuse_depth)501 explicit GraphPartitioner(support::Arena* arena, int opt_level, size_t max_fuse_depth)
502 : arena_(arena), opt_level_(opt_level), max_fuse_depth_(max_fuse_depth) {}
503 /*!
504 * \brief Group as a union find data structure.
505 */
506 struct Group {
507 /*! \brief The parent in the union find data structure. */
508 Group* parent{nullptr};
509 /*! \brief The pattern of the group */
510 OpPatternKind pattern;
511 /*! \brief reference to the root node. */
512 const tvm::Object* root_ref{nullptr};
513 /*!
514 * \brief Reference to the master node,
515 * this field is not nullptr only if pattern is kOutEWiseFusable.
516 */
517 const tvm::Object* master_ref{nullptr};
518 /*!
519 * \brief Find the group root, perform path compression
520 * \return The root type node.
521 */
FindRoottvm::relay::GraphPartitioner::Group522 Group* FindRoot() {
523 // fast path
524 if (this->parent == nullptr) return this;
525 // slow path with path compression.
526 Group* root = this;
527 while (root->parent != nullptr) {
528 root = root->parent;
529 }
530 for (Group* p = this; p != root;) {
531 Group* parent = p->parent;
532 p->parent = root;
533 p = parent;
534 }
535 return root;
536 }
537
538 /*!
539 * \brief The number of nodes belonging to this group
540 */
541 uint32_t num_nodes{1};
542 };
543 /*!
544 * \brief Partition a graph.
545 * \return group assignments of each node.
546 */
547 std::vector<Group*> Partition(const IndexedForwardGraph& graph);
548
549 private:
550 /*! \brief The internal arena for temporary space. */
551 support::Arena* arena_;
552 /*! \brief optimization level for fuse operation. */
553 int opt_level_;
554 /*! \brief The maximum number of operations in one fused function */
555 size_t max_fuse_depth_;
556 /*! \brief The internal groups. */
557 std::vector<Group*> groups_;
558 /*! \brief internal field used for deduplication */
559 std::unordered_set<IndexedForwardGraph::Node*> visited_;
560 // Internal implelementation of CheckPath
561 template <typename F>
CheckPath_(IndexedForwardGraph::Node * src,IndexedForwardGraph::Node * sink,F fcond)562 bool CheckPath_(IndexedForwardGraph::Node* src, IndexedForwardGraph::Node* sink, F fcond) {
563 if (visited_.count(src)) return true;
564 visited_.insert(src);
565 Group* gnode = groups_[src->index];
566 CHECK(gnode != nullptr);
567 gnode = gnode->FindRoot();
568 if (!fcond(gnode->pattern, src == sink)) return false;
569 if (src == sink) return true;
570 for (auto link = src->outputs.head; link != nullptr; link = link->next) {
571 if (!CheckPath_(link->value.node, sink, fcond)) return false;
572 }
573 return true;
574 }
575 /*!
576 * \brief Check all the node and edge pattern
577 * between src and sink satisfies fcond.
578 *
579 * src is not checked.
580 *
581 * \param src The source node.
582 * \param sink The termination node.
583 * \param fcond The condition to be checked.
584 * \tparam F the condition function, with signature
585 * \note sink must be a post-dominator of src.
586 */
587 template <typename F>
CheckPath(IndexedForwardGraph::Node * src,IndexedForwardGraph::Node * sink,F fcond)588 bool CheckPath(IndexedForwardGraph::Node* src, IndexedForwardGraph::Node* sink, F fcond) {
589 CHECK(!src->extern_ref);
590 visited_.clear();
591 CHECK(src != sink);
592 for (auto link = src->outputs.head; link != nullptr; link = link->next) {
593 if (!CheckPath_(link->value.node, sink, fcond)) return false;
594 }
595 return true;
596 }
597 // Combine two patterns together.
CombinePattern(OpPatternKind lhs,OpPatternKind rhs)598 static OpPatternKind CombinePattern(OpPatternKind lhs, OpPatternKind rhs) {
599 if (lhs > kBroadcast && rhs > kBroadcast) {
600 LOG(FATAL) << "Cannot merge two complex group together";
601 }
602 if (lhs > rhs) return lhs;
603 return rhs;
604 }
605 /*!
606 * \brief Merge the child group to the parent.
607 * \param child The child group.
608 * \param parent The parent group.
609 */
MergeFromTo(Group * child,Group * parent)610 void MergeFromTo(Group* child, Group* parent) {
611 child = child->FindRoot();
612 parent = parent->FindRoot();
613 if (child == parent) return;
614 // update the number of nodes of the parent group
615 parent->num_nodes += child->num_nodes;
616 child->parent = parent;
617 // update master ref and pattern
618 if (child->master_ref != nullptr) {
619 CHECK(parent->master_ref == nullptr);
620 parent->master_ref = child->master_ref;
621 parent->pattern = CombinePattern(child->pattern, parent->pattern);
622 }
623 }
624 // Internal implelementation of CommitFuse
CommitFuse_(IndexedForwardGraph::Node * src,IndexedForwardGraph::Node * sink,Group * target)625 void CommitFuse_(IndexedForwardGraph::Node* src, IndexedForwardGraph::Node* sink, Group* target) {
626 if (src == sink) return;
627 if (visited_.count(src)) return;
628 visited_.insert(src);
629 Group* gnode = groups_[src->index];
630 CHECK(gnode != nullptr);
631 // merge the current group to the parent if possible.
632 MergeFromTo(gnode, target);
633 for (auto link = src->outputs.head; link != nullptr; link = link->next) {
634 CommitFuse_(link->value.node, sink, target);
635 }
636 }
637 /*!
638 * \brief Commit fusion operation.
639 * \param src The source node.
640 * \param sink The termination node.
641 * \note sink must be a post-dominator of src.
642 */
CommitFuse(IndexedForwardGraph::Node * src,IndexedForwardGraph::Node * sink)643 void CommitFuse(IndexedForwardGraph::Node* src, IndexedForwardGraph::Node* sink) {
644 Group* target = groups_[sink->index];
645 visited_.clear();
646 CHECK(src != sink);
647 CommitFuse_(src, sink, target);
648 }
649
CountNodesUptoSink_(IndexedForwardGraph::Node * src,IndexedForwardGraph::Node * sink)650 size_t CountNodesUptoSink_(IndexedForwardGraph::Node* src, IndexedForwardGraph::Node* sink) {
651 if (src == sink || visited_.count(src)) return 0;
652 visited_.insert(src);
653 Group* gnode = groups_[src->index];
654 CHECK(gnode != nullptr);
655 auto sum = gnode->num_nodes;
656 for (auto link = src->outputs.head; link != nullptr; link = link->next) {
657 sum += CountNodesUptoSink_(link->value.node, sink);
658 }
659 return sum;
660 }
661
662 // Count the number of nodes in a fused subgraph if child is additionaly fused.
663 // dom_parent is already known to be a part of the subgraph.
664 // For a diamond structure, there can be multiple paths connecting child and dom_parent.
665 // All intermediate nodes between child and dom_parent are taken into account.
666 // Since dom_parent can itself be an intermediate node in the subgraph, calling FindRoot()
667 // is important for correct calculation.
CountFusedNodesWithNewChild(IndexedForwardGraph::Node * child,IndexedForwardGraph::Node * dom_parent)668 size_t CountFusedNodesWithNewChild(IndexedForwardGraph::Node* child,
669 IndexedForwardGraph::Node* dom_parent) {
670 Group* target = groups_[dom_parent->index];
671 visited_.clear();
672 CHECK(child != dom_parent);
673 return target->FindRoot()->num_nodes + CountNodesUptoSink_(child, dom_parent);
674 }
675
676 // Initialize the groups.
InitGroups(const IndexedForwardGraph & graph)677 void InitGroups(const IndexedForwardGraph& graph) {
678 groups_.resize(graph.post_dfs_order.size());
679 for (size_t nid = 0; nid < groups_.size(); ++nid) {
680 const auto* graph_node = graph.post_dfs_order[nid];
681 auto* group_node = arena_->make<Group>();
682 group_node->pattern = graph_node->pattern;
683 group_node->root_ref = graph_node->ref;
684 // set master ref if necessary.
685 if (group_node->pattern == kOutEWiseFusable) {
686 group_node->master_ref = graph_node->ref;
687 }
688 groups_[nid] = group_node;
689 }
690 }
691
692 // execute the fusion algorithm.
RunFuse(const IndexedForwardGraph & graph,const DominatorTree & post_dom_tree,int phase)693 void RunFuse(const IndexedForwardGraph& graph, const DominatorTree& post_dom_tree, int phase) {
694 for (size_t nid = 0; nid < groups_.size(); ++nid) {
695 // the group of current node has been specified already.
696 auto* graph_node = graph.post_dfs_order[nid];
697 auto* dom_node = post_dom_tree.nodes[nid];
698 Group* group_node = groups_[nid];
699 CHECK(group_node != nullptr);
700 // no actions for opaque nodes
701 if (group_node->pattern == kOpaque) continue;
702 // no actions needed if the current node have no dominator
703 if (dom_node->parent == nullptr) continue;
704 CHECK(!graph_node->extern_ref);
705 size_t dom_parent_gindex = dom_node->parent->gnode->index;
706
707 // refuse the fusion if too many ops are going to be fused together
708 if (CountFusedNodesWithNewChild(graph_node, dom_node->parent->gnode) > max_fuse_depth_)
709 continue;
710
711 if (phase == 2) {
712 // Fuse injective ops into intermediate tuples, if any
713 if (group_node->pattern > kInjective) continue;
714 Group* dom_parent_group = groups_[dom_parent_gindex];
715 Group* dom_root_group = dom_parent_group->FindRoot();
716 // If dom node group has a tuple as its root, we do not fuse tuple fields into it
717 if (dom_root_group->pattern == kTuple) continue;
718 if (dom_parent_group->pattern == kTuple && dom_root_group->pattern <= kInjective) {
719 // Now we know the tuple has been fused into subsequent injective ops
720 auto fcond = [](OpPatternKind kind, bool is_sink) { return kind <= kInjective; };
721 // dom_root_group can also be tuple, as in inception layers
722 // CheckPath is needed to avoid fusing two intermediate tuples
723 if (CheckPath(graph_node, dom_node->parent->gnode, fcond)) {
724 CommitFuse(graph_node, dom_node->parent->gnode);
725 }
726 }
727 continue;
728 }
729
730 // Skip if current node is already fused to the parent.
731 if (groups_[dom_parent_gindex] != nullptr &&
732 group_node->FindRoot() == groups_[dom_parent_gindex]->FindRoot()) {
733 continue;
734 }
735 // Do not fuse into tuple for now
736 if (groups_[dom_parent_gindex]->pattern == kTuple) continue;
737 // Try to fuse current node to its post-dominator.
738 if (group_node->pattern == kOutEWiseFusable) {
739 if (phase != 0) continue;
740 // Path for OutEWiseFusable: conv2d
741 // Check if the dominator relation is elemwise.
742 if (dom_node->parent != nullptr && dom_node->pattern == kElemWise) {
743 CHECK(dom_node->parent->gnode != nullptr);
744 // The fuse can be executed if all the intermediate ops are still broadcast.
745 auto fcond = [](OpPatternKind kind, bool is_sink) { return kind <= kBroadcast; };
746 if (CheckPath(graph_node, dom_node->parent->gnode, fcond)) {
747 CommitFuse(graph_node, dom_node->parent->gnode);
748 }
749 }
750 } else if (group_node->pattern <= kBroadcast) {
751 // Pre-condition: can only be fused to parent which is injective or reduction.
752 if (dom_node->parent != nullptr &&
753 (dom_node->pattern <= kInjective || dom_node->pattern == kCommReduce)) {
754 // Check if all the intermediate ops are still broadcast.
755 // The final terminal node can already be fused to a OutEWiseFusable group.
756 auto fcond = [](OpPatternKind kind, bool is_sink) {
757 if (!is_sink) {
758 // Elemwise, broadcast, and injective ops on the parallel branches
759 // are allowed be fused to the elemwise/broadcast master.
760 return kind <= kInjective;
761 } else {
762 return (kind <= kBroadcast || kind == kCommReduce || kind == kInjective ||
763 kind == kOutEWiseFusable);
764 }
765 };
766 if (CheckPath(graph_node, dom_node->parent->gnode, fcond)) {
767 CommitFuse(graph_node, dom_node->parent->gnode);
768 }
769 }
770 } else if (group_node->pattern == kInjective || group_node->pattern == kTuple) {
771 // defer injective fusion to second phase.
772 // so conv2d always finishes fusing.
773 if (phase != 1) continue;
774 // Check if all path are injective.
775 auto fcond = [](OpPatternKind kind, bool is_sink) { return kind <= kInjective; };
776 if (CheckPath(graph_node, dom_node->parent->gnode, fcond)) {
777 CommitFuse(graph_node, dom_node->parent->gnode);
778 }
779 } else {
780 // do nothing.
781 CHECK(group_node->pattern == kCommReduce);
782 }
783 }
784 }
785 };
786
Partition(const IndexedForwardGraph & graph)787 std::vector<GraphPartitioner::Group*> GraphPartitioner::Partition(
788 const IndexedForwardGraph& graph) {
789 this->InitGroups(graph);
790 if (opt_level_ == 0) return std::move(groups_);
791 // get post dominator tree
792 auto post_dom_tree = DominatorTree::PostDom(arena_, graph);
793 // run fusion algorithm.
794 for (int phase = 0; phase < 3; ++phase) {
795 this->RunFuse(graph, post_dom_tree, phase);
796 }
797 return std::move(groups_);
798 }
799
800 class FuseMutator : private ExprMutator {
801 public:
802 // Run the transform
Transform(const Expr & body,int fuse_opt_level,size_t max_fuse_depth)803 Expr Transform(const Expr& body, int fuse_opt_level, size_t max_fuse_depth) {
804 // setup the group map.
805 auto graph = IndexedForwardGraph::Create(&arena_, body);
806 auto groups = GraphPartitioner(&arena_, fuse_opt_level, max_fuse_depth).Partition(graph);
807 for (size_t nid = 0; nid < graph.post_dfs_order.size(); ++nid) {
808 CHECK(graph.post_dfs_order[nid]->ref != nullptr);
809 gmap_[graph.post_dfs_order[nid]->ref] = groups[nid];
810 }
811 // The following line can be used for debug.
812 // this->DebugDumpGroup(body);
813 return this->Mutate(body);
814 }
815
816 private:
817 /*! \brief Temporary information from each group. */
818 struct GroupInfo {
819 public:
820 // The parameters of the function.
821 Array<Var> params;
822 // The arguments to call the functions.
823 Array<Expr> arguments;
824 // Get a new parameter or allocate an old one
GetOrAllocParamtvm::relay::FuseMutator::GroupInfo825 Var GetOrAllocParam(const Expr& expr, const Type& type) {
826 // run linear scan as most fused groups contain only a few inputs.
827 for (size_t i = 0; i < arguments.size(); ++i) {
828 if (expr.same_as(arguments[i])) return params[i];
829 }
830 // create a new parameter.
831 std::ostringstream os;
832 os << "p" << params.size();
833 auto var = Var(os.str(), type);
834 params.push_back(var);
835 arguments.push_back(expr);
836 return var;
837 }
838 };
839 /*! \brief Internal arena. */
840 support::Arena arena_;
841 /*! \brief The group assignment map. */
842 std::unordered_map<const Object*, GraphPartitioner::Group*> gmap_;
843 /* \brief Internal group information map. */
844 std::unordered_map<GraphPartitioner::Group*, GroupInfo> ginfo_;
845
846 // Skip primitive function.
VisitExpr_(const FunctionNode * fn_node)847 Expr VisitExpr_(const FunctionNode* fn_node) {
848 if (fn_node->HasNonzeroAttr(attr::kPrimitive)) {
849 return GetRef<Expr>(fn_node);
850 } else {
851 return ExprMutator::VisitExpr_(fn_node);
852 }
853 }
854
855 // Transform calls.
VisitExpr_(const CallNode * call)856 Expr VisitExpr_(const CallNode* call) {
857 if (call->op.as<OpNode>()) {
858 static auto fnoncomputational = Op::GetAttrMap<TNonComputational>("TNonComputational");
859
860 if (fnoncomputational.get(Downcast<Op>(call->op), false)) {
861 return ExprMutator::VisitExpr_(call);
862 }
863
864 // If it is a primitive op call
865 // then we must have a group assignment for it already.
866 CHECK(gmap_.count(call));
867 if (call->op == stop_fusion_op) {
868 return ExprMutator::VisitExpr(call->args[0]);
869 }
870 auto* ret_group = gmap_.at(call)->FindRoot();
871 Array<Expr> new_args = GetNewArguments(call->args, ret_group);
872
873 auto new_call = Call(call->op, new_args, call->attrs, call->type_args);
874
875 if (ret_group->root_ref == call) {
876 // This is the root of the group
877 // create the new call node.
878 return MakeNewFunction(ret_group, call->checked_type(), new_call);
879 } else {
880 // This is an intermediate node of a fused function
881 // simply return the new call.
882 return std::move(new_call);
883 }
884 } else {
885 return ExprMutator::VisitExpr_(call);
886 }
887 }
888
VisitExpr_(const TupleNode * tuple)889 Expr VisitExpr_(const TupleNode* tuple) {
890 auto* ret_group = gmap_.at(tuple)->FindRoot();
891 if (ret_group->root_ref == tuple) {
892 return ExprMutator::VisitExpr_(tuple);
893 }
894 // This tuple is an intermediate node in the group
895 Array<Expr> new_fields = GetNewArguments(tuple->fields, ret_group);
896 return Tuple(new_fields);
897 }
898
VisitExpr_(const TupleGetItemNode * tuple_get)899 Expr VisitExpr_(const TupleGetItemNode* tuple_get) {
900 auto* ret_group = gmap_.at(tuple_get)->FindRoot();
901 auto new_tuple = GetNewArguments({tuple_get->tuple}, ret_group)[0];
902 auto new_node = TupleGetItem(new_tuple, tuple_get->index);
903 if (ret_group->root_ref == tuple_get) {
904 if (gmap_.at(tuple_get->tuple.get())->FindRoot() != ret_group) {
905 // Isolated. This case occurs when tuple is created by an Opaque op
906 // e.g. multibox_transform_loc
907 return ExprMutator::VisitExpr_(tuple_get);
908 }
909 // A new function whose output is a tuple field access
910 return MakeNewFunction(ret_group, tuple_get->checked_type(), new_node);
911 }
912 // This is an intermediate node in the group
913 return std::move(new_node);
914 }
915
MakeNewFunction(GraphPartitioner::Group * group,Type ret_type,Expr body)916 Expr MakeNewFunction(GraphPartitioner::Group* group, Type ret_type, Expr body) {
917 // If the function has no call, it is not a primitive function.
918 struct HasCallVisitor : ExprVisitor {
919 bool has_call = false;
920 void VisitExpr_(const CallNode* op) final { has_call = true; }
921 } visitor;
922 visitor(body);
923 const GroupInfo& ginfo = ginfo_[group];
924 auto func = Function(ginfo.params, body, ret_type, {});
925 func = WithAttr(std::move(func), attr::kPrimitive, tvm::Integer(visitor.has_call));
926 return Call(func, ginfo.arguments, Attrs());
927 }
928
GetNewArguments(const tvm::Array<Expr> & args,GraphPartitioner::Group * current_group)929 Array<Expr> GetNewArguments(const tvm::Array<Expr>& args,
930 GraphPartitioner::Group* current_group) {
931 Array<Expr> new_args;
932 for (auto arg : args) {
933 auto* arg_group = gmap_.at(arg.get())->FindRoot();
934 auto type = arg->checked_type();
935 Expr new_arg = this->Mutate(arg);
936 if (current_group != arg_group) {
937 Var param = ginfo_[current_group].GetOrAllocParam(new_arg, type);
938 new_args.push_back(param);
939 } else {
940 new_args.push_back(new_arg);
941 }
942 }
943 return new_args;
944 }
945
946 // Debug function, dump the group assignment in text.
DebugDumpGroup(const Expr & body)947 void DebugDumpGroup(const Expr& body) {
948 std::string text = AsText(body, false, [this](const ObjectRef& expr) -> std::string {
949 auto it = gmap_.find(expr.get());
950 if (it == gmap_.end()) return "";
951 std::ostringstream os;
952 auto* group = it->second->FindRoot();
953 os << " /* group=" << group << " */";
954 return os.str();
955 });
956 LOG(INFO) << "Dump of group info:\n" << text;
957 }
958 };
959
/*!
 * \brief Run operator fusion on an expression.
 * \param expr The expression to rewrite.
 * \param fuse_opt_level Optimization level forwarded to FuseMutator::Transform.
 * \param max_fuse_depth Depth limit forwarded to FuseMutator::Transform.
 * \param module Not used by this implementation.
 * \return The expression produced by FuseMutator::Transform.
 */
Expr FuseOps(const Expr& expr, int fuse_opt_level, size_t max_fuse_depth, const IRModule& module) {
  FuseMutator mutator;
  return mutator.Transform(expr, fuse_opt_level, max_fuse_depth);
}
963
964 namespace transform {
965
FuseOps(int fuse_opt_level)966 Pass FuseOps(int fuse_opt_level) {
967 runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
968 [=](Function f, IRModule m, PassContext pc) {
969 int opt_level = fuse_opt_level == -1 ? pc->opt_level : fuse_opt_level;
970 auto max_fuse_depth = pc->GetConfig("relay.FuseOps.max_depth", Integer(kMaxFusedOps));
971 return Downcast<Function>(FuseOps(f, opt_level, max_fuse_depth.value(), m));
972 };
973 return CreateFunctionPass(pass_func, 1, "FuseOps", {"InferType"});
974 }
975
976 TVM_REGISTER_GLOBAL("relay._transform.FuseOps").set_body_typed(FuseOps);
977
978 } // namespace transform
979
980 } // namespace relay
981 } // namespace tvm
982