1 //===- Inliner.cpp - Pass to inline function calls ------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This file implements a basic inlining algorithm that operates bottom up over
// the Strongly Connected Components (SCCs) of the CallGraph. This enables a
// more incremental propagation of inlining decisions from the leaves to the
// roots of the callgraph.
13 //
14 //===----------------------------------------------------------------------===//
15
16 #include "PassDetail.h"
17 #include "mlir/Analysis/CallGraph.h"
18 #include "mlir/IR/Threading.h"
19 #include "mlir/Interfaces/SideEffectInterfaces.h"
20 #include "mlir/Pass/PassManager.h"
21 #include "mlir/Transforms/InliningUtils.h"
22 #include "mlir/Transforms/Passes.h"
23 #include "llvm/ADT/SCCIterator.h"
24 #include "llvm/Support/Debug.h"
25 #include "llvm/Support/Parallel.h"
26
27 #define DEBUG_TYPE "inlining"
28
29 using namespace mlir;
30
31 /// This function implements the default inliner optimization pipeline.
defaultInlinerOptPipeline(OpPassManager & pm)32 static void defaultInlinerOptPipeline(OpPassManager &pm) {
33 pm.addPass(createCanonicalizerPass());
34 }
35
36 //===----------------------------------------------------------------------===//
37 // Symbol Use Tracking
38 //===----------------------------------------------------------------------===//
39
40 /// Walk all of the used symbol callgraph nodes referenced with the given op.
walkReferencedSymbolNodes(Operation * op,CallGraph & cg,SymbolTableCollection & symbolTable,DenseMap<Attribute,CallGraphNode * > & resolvedRefs,function_ref<void (CallGraphNode *,Operation *)> callback)41 static void walkReferencedSymbolNodes(
42 Operation *op, CallGraph &cg, SymbolTableCollection &symbolTable,
43 DenseMap<Attribute, CallGraphNode *> &resolvedRefs,
44 function_ref<void(CallGraphNode *, Operation *)> callback) {
45 auto symbolUses = SymbolTable::getSymbolUses(op);
46 assert(symbolUses && "expected uses to be valid");
47
48 Operation *symbolTableOp = op->getParentOp();
49 for (const SymbolTable::SymbolUse &use : *symbolUses) {
50 auto refIt = resolvedRefs.insert({use.getSymbolRef(), nullptr});
51 CallGraphNode *&node = refIt.first->second;
52
53 // If this is the first instance of this reference, try to resolve a
54 // callgraph node for it.
55 if (refIt.second) {
56 auto *symbolOp = symbolTable.lookupNearestSymbolFrom(symbolTableOp,
57 use.getSymbolRef());
58 auto callableOp = dyn_cast_or_null<CallableOpInterface>(symbolOp);
59 if (!callableOp)
60 continue;
61 node = cg.lookupNode(callableOp.getCallableRegion());
62 }
63 if (node)
64 callback(node, use.getUser());
65 }
66 }
67
//===----------------------------------------------------------------------===//
// CGUseList
//===----------------------------------------------------------------------===//

71 namespace {
72 /// This struct tracks the uses of callgraph nodes that can be dropped when
73 /// use_empty. It directly tracks and manages a use-list for all of the
74 /// call-graph nodes. This is necessary because many callgraph nodes are
75 /// referenced by SymbolRefAttr, which has no mechanism akin to the SSA `Use`
76 /// class.
77 struct CGUseList {
78 /// This struct tracks the uses of callgraph nodes within a specific
79 /// operation.
80 struct CGUser {
81 /// Any nodes referenced in the top-level attribute list of this user. We
82 /// use a set here because the number of references does not matter.
83 DenseSet<CallGraphNode *> topLevelUses;
84
85 /// Uses of nodes referenced by nested operations.
86 DenseMap<CallGraphNode *, int> innerUses;
87 };
88
89 CGUseList(Operation *op, CallGraph &cg, SymbolTableCollection &symbolTable);
90
91 /// Drop uses of nodes referred to by the given call operation that resides
92 /// within 'userNode'.
93 void dropCallUses(CallGraphNode *userNode, Operation *callOp, CallGraph &cg);
94
95 /// Remove the given node from the use list.
96 void eraseNode(CallGraphNode *node);
97
98 /// Returns true if the given callgraph node has no uses and can be pruned.
99 bool isDead(CallGraphNode *node) const;
100
101 /// Returns true if the given callgraph node has a single use and can be
102 /// discarded.
103 bool hasOneUseAndDiscardable(CallGraphNode *node) const;
104
105 /// Recompute the uses held by the given callgraph node.
106 void recomputeUses(CallGraphNode *node, CallGraph &cg);
107
108 /// Merge the uses of 'lhs' with the uses of the 'rhs' after inlining a copy
109 /// of 'lhs' into 'rhs'.
110 void mergeUsesAfterInlining(CallGraphNode *lhs, CallGraphNode *rhs);
111
112 private:
113 /// Decrement the uses of discardable nodes referenced by the given user.
114 void decrementDiscardableUses(CGUser &uses);
115
116 /// A mapping between a discardable callgraph node (that is a symbol) and the
117 /// number of uses for this node.
118 DenseMap<CallGraphNode *, int> discardableSymNodeUses;
119
120 /// A mapping between a callgraph node and the symbol callgraph nodes that it
121 /// uses.
122 DenseMap<CallGraphNode *, CGUser> nodeUses;
123
124 /// A symbol table to use when resolving call lookups.
125 SymbolTableCollection &symbolTable;
126 };
127 } // end anonymous namespace
128
CGUseList(Operation * op,CallGraph & cg,SymbolTableCollection & symbolTable)129 CGUseList::CGUseList(Operation *op, CallGraph &cg,
130 SymbolTableCollection &symbolTable)
131 : symbolTable(symbolTable) {
132 /// A set of callgraph nodes that are always known to be live during inlining.
133 DenseMap<Attribute, CallGraphNode *> alwaysLiveNodes;
134
135 // Walk each of the symbol tables looking for discardable callgraph nodes.
136 auto walkFn = [&](Operation *symbolTableOp, bool allUsesVisible) {
137 for (Operation &op : symbolTableOp->getRegion(0).getOps()) {
138 // If this is a callgraph operation, check to see if it is discardable.
139 if (auto callable = dyn_cast<CallableOpInterface>(&op)) {
140 if (auto *node = cg.lookupNode(callable.getCallableRegion())) {
141 SymbolOpInterface symbol = dyn_cast<SymbolOpInterface>(&op);
142 if (symbol && (allUsesVisible || symbol.isPrivate()) &&
143 symbol.canDiscardOnUseEmpty()) {
144 discardableSymNodeUses.try_emplace(node, 0);
145 }
146 continue;
147 }
148 }
149 // Otherwise, check for any referenced nodes. These will be always-live.
150 walkReferencedSymbolNodes(&op, cg, symbolTable, alwaysLiveNodes,
151 [](CallGraphNode *, Operation *) {});
152 }
153 };
154 SymbolTable::walkSymbolTables(op, /*allSymUsesVisible=*/!op->getBlock(),
155 walkFn);
156
157 // Drop the use information for any discardable nodes that are always live.
158 for (auto &it : alwaysLiveNodes)
159 discardableSymNodeUses.erase(it.second);
160
161 // Compute the uses for each of the callable nodes in the graph.
162 for (CallGraphNode *node : cg)
163 recomputeUses(node, cg);
164 }
165
dropCallUses(CallGraphNode * userNode,Operation * callOp,CallGraph & cg)166 void CGUseList::dropCallUses(CallGraphNode *userNode, Operation *callOp,
167 CallGraph &cg) {
168 auto &userRefs = nodeUses[userNode].innerUses;
169 auto walkFn = [&](CallGraphNode *node, Operation *user) {
170 auto parentIt = userRefs.find(node);
171 if (parentIt == userRefs.end())
172 return;
173 --parentIt->second;
174 --discardableSymNodeUses[node];
175 };
176 DenseMap<Attribute, CallGraphNode *> resolvedRefs;
177 walkReferencedSymbolNodes(callOp, cg, symbolTable, resolvedRefs, walkFn);
178 }
179
eraseNode(CallGraphNode * node)180 void CGUseList::eraseNode(CallGraphNode *node) {
181 // Drop all child nodes.
182 for (auto &edge : *node)
183 if (edge.isChild())
184 eraseNode(edge.getTarget());
185
186 // Drop the uses held by this node and erase it.
187 auto useIt = nodeUses.find(node);
188 assert(useIt != nodeUses.end() && "expected node to be valid");
189 decrementDiscardableUses(useIt->getSecond());
190 nodeUses.erase(useIt);
191 discardableSymNodeUses.erase(node);
192 }
193
isDead(CallGraphNode * node) const194 bool CGUseList::isDead(CallGraphNode *node) const {
195 // If the parent operation isn't a symbol, simply check normal SSA deadness.
196 Operation *nodeOp = node->getCallableRegion()->getParentOp();
197 if (!isa<SymbolOpInterface>(nodeOp))
198 return MemoryEffectOpInterface::hasNoEffect(nodeOp) && nodeOp->use_empty();
199
200 // Otherwise, check the number of symbol uses.
201 auto symbolIt = discardableSymNodeUses.find(node);
202 return symbolIt != discardableSymNodeUses.end() && symbolIt->second == 0;
203 }
204
hasOneUseAndDiscardable(CallGraphNode * node) const205 bool CGUseList::hasOneUseAndDiscardable(CallGraphNode *node) const {
206 // If this isn't a symbol node, check for side-effects and SSA use count.
207 Operation *nodeOp = node->getCallableRegion()->getParentOp();
208 if (!isa<SymbolOpInterface>(nodeOp))
209 return MemoryEffectOpInterface::hasNoEffect(nodeOp) && nodeOp->hasOneUse();
210
211 // Otherwise, check the number of symbol uses.
212 auto symbolIt = discardableSymNodeUses.find(node);
213 return symbolIt != discardableSymNodeUses.end() && symbolIt->second == 1;
214 }
215
recomputeUses(CallGraphNode * node,CallGraph & cg)216 void CGUseList::recomputeUses(CallGraphNode *node, CallGraph &cg) {
217 Operation *parentOp = node->getCallableRegion()->getParentOp();
218 CGUser &uses = nodeUses[node];
219 decrementDiscardableUses(uses);
220
221 // Collect the new discardable uses within this node.
222 uses = CGUser();
223 DenseMap<Attribute, CallGraphNode *> resolvedRefs;
224 auto walkFn = [&](CallGraphNode *refNode, Operation *user) {
225 auto discardSymIt = discardableSymNodeUses.find(refNode);
226 if (discardSymIt == discardableSymNodeUses.end())
227 return;
228
229 if (user != parentOp)
230 ++uses.innerUses[refNode];
231 else if (!uses.topLevelUses.insert(refNode).second)
232 return;
233 ++discardSymIt->second;
234 };
235 walkReferencedSymbolNodes(parentOp, cg, symbolTable, resolvedRefs, walkFn);
236 }
237
mergeUsesAfterInlining(CallGraphNode * lhs,CallGraphNode * rhs)238 void CGUseList::mergeUsesAfterInlining(CallGraphNode *lhs, CallGraphNode *rhs) {
239 auto &lhsUses = nodeUses[lhs], &rhsUses = nodeUses[rhs];
240 for (auto &useIt : lhsUses.innerUses) {
241 rhsUses.innerUses[useIt.first] += useIt.second;
242 discardableSymNodeUses[useIt.first] += useIt.second;
243 }
244 }
245
decrementDiscardableUses(CGUser & uses)246 void CGUseList::decrementDiscardableUses(CGUser &uses) {
247 for (CallGraphNode *node : uses.topLevelUses)
248 --discardableSymNodeUses[node];
249 for (auto &it : uses.innerUses)
250 discardableSymNodeUses[it.first] -= it.second;
251 }
252
253 //===----------------------------------------------------------------------===//
254 // CallGraph traversal
255 //===----------------------------------------------------------------------===//
256
257 namespace {
258 /// This class represents a specific callgraph SCC.
259 class CallGraphSCC {
260 public:
CallGraphSCC(llvm::scc_iterator<const CallGraph * > & parentIterator)261 CallGraphSCC(llvm::scc_iterator<const CallGraph *> &parentIterator)
262 : parentIterator(parentIterator) {}
263 /// Return a range over the nodes within this SCC.
begin()264 std::vector<CallGraphNode *>::iterator begin() { return nodes.begin(); }
end()265 std::vector<CallGraphNode *>::iterator end() { return nodes.end(); }
266
267 /// Reset the nodes of this SCC with those provided.
reset(const std::vector<CallGraphNode * > & newNodes)268 void reset(const std::vector<CallGraphNode *> &newNodes) { nodes = newNodes; }
269
270 /// Remove the given node from this SCC.
remove(CallGraphNode * node)271 void remove(CallGraphNode *node) {
272 auto it = llvm::find(nodes, node);
273 if (it != nodes.end()) {
274 nodes.erase(it);
275 parentIterator.ReplaceNode(node, nullptr);
276 }
277 }
278
279 private:
280 std::vector<CallGraphNode *> nodes;
281 llvm::scc_iterator<const CallGraph *> &parentIterator;
282 };
283 } // end anonymous namespace
284
285 /// Run a given transformation over the SCCs of the callgraph in a bottom up
286 /// traversal.
runTransformOnCGSCCs(const CallGraph & cg,function_ref<LogicalResult (CallGraphSCC &)> sccTransformer)287 static LogicalResult runTransformOnCGSCCs(
288 const CallGraph &cg,
289 function_ref<LogicalResult(CallGraphSCC &)> sccTransformer) {
290 llvm::scc_iterator<const CallGraph *> cgi = llvm::scc_begin(&cg);
291 CallGraphSCC currentSCC(cgi);
292 while (!cgi.isAtEnd()) {
293 // Copy the current SCC and increment so that the transformer can modify the
294 // SCC without invalidating our iterator.
295 currentSCC.reset(*cgi);
296 ++cgi;
297 if (failed(sccTransformer(currentSCC)))
298 return failure();
299 }
300 return success();
301 }
302
303 namespace {
304 /// This struct represents a resolved call to a given callgraph node. Given that
305 /// the call does not actually contain a direct reference to the
306 /// Region(CallGraphNode) that it is dispatching to, we need to resolve them
307 /// explicitly.
308 struct ResolvedCall {
ResolvedCall__anon7742a5a60711::ResolvedCall309 ResolvedCall(CallOpInterface call, CallGraphNode *sourceNode,
310 CallGraphNode *targetNode)
311 : call(call), sourceNode(sourceNode), targetNode(targetNode) {}
312 CallOpInterface call;
313 CallGraphNode *sourceNode, *targetNode;
314 };
315 } // end anonymous namespace
316
317 /// Collect all of the callable operations within the given range of blocks. If
318 /// `traverseNestedCGNodes` is true, this will also collect call operations
319 /// inside of nested callgraph nodes.
collectCallOps(iterator_range<Region::iterator> blocks,CallGraphNode * sourceNode,CallGraph & cg,SymbolTableCollection & symbolTable,SmallVectorImpl<ResolvedCall> & calls,bool traverseNestedCGNodes)320 static void collectCallOps(iterator_range<Region::iterator> blocks,
321 CallGraphNode *sourceNode, CallGraph &cg,
322 SymbolTableCollection &symbolTable,
323 SmallVectorImpl<ResolvedCall> &calls,
324 bool traverseNestedCGNodes) {
325 SmallVector<std::pair<Block *, CallGraphNode *>, 8> worklist;
326 auto addToWorklist = [&](CallGraphNode *node,
327 iterator_range<Region::iterator> blocks) {
328 for (Block &block : blocks)
329 worklist.emplace_back(&block, node);
330 };
331
332 addToWorklist(sourceNode, blocks);
333 while (!worklist.empty()) {
334 Block *block;
335 std::tie(block, sourceNode) = worklist.pop_back_val();
336
337 for (Operation &op : *block) {
338 if (auto call = dyn_cast<CallOpInterface>(op)) {
339 // TODO: Support inlining nested call references.
340 CallInterfaceCallable callable = call.getCallableForCallee();
341 if (SymbolRefAttr symRef = callable.dyn_cast<SymbolRefAttr>()) {
342 if (!symRef.isa<FlatSymbolRefAttr>())
343 continue;
344 }
345
346 CallGraphNode *targetNode = cg.resolveCallable(call, symbolTable);
347 if (!targetNode->isExternal())
348 calls.emplace_back(call, sourceNode, targetNode);
349 continue;
350 }
351
352 // If this is not a call, traverse the nested regions. If
353 // `traverseNestedCGNodes` is false, then don't traverse nested call graph
354 // regions.
355 for (auto &nestedRegion : op.getRegions()) {
356 CallGraphNode *nestedNode = cg.lookupNode(&nestedRegion);
357 if (traverseNestedCGNodes || !nestedNode)
358 addToWorklist(nestedNode ? nestedNode : sourceNode, nestedRegion);
359 }
360 }
361 }
362 }
363
364 //===----------------------------------------------------------------------===//
365 // Inliner
366 //===----------------------------------------------------------------------===//
367 namespace {
368 /// This class provides a specialization of the main inlining interface.
369 struct Inliner : public InlinerInterface {
Inliner__anon7742a5a60911::Inliner370 Inliner(MLIRContext *context, CallGraph &cg,
371 SymbolTableCollection &symbolTable)
372 : InlinerInterface(context), cg(cg), symbolTable(symbolTable) {}
373
374 /// Process a set of blocks that have been inlined. This callback is invoked
375 /// *before* inlined terminator operations have been processed.
376 void
processInlinedBlocks__anon7742a5a60911::Inliner377 processInlinedBlocks(iterator_range<Region::iterator> inlinedBlocks) final {
378 // Find the closest callgraph node from the first block.
379 CallGraphNode *node;
380 Region *region = inlinedBlocks.begin()->getParent();
381 while (!(node = cg.lookupNode(region))) {
382 region = region->getParentRegion();
383 assert(region && "expected valid parent node");
384 }
385
386 collectCallOps(inlinedBlocks, node, cg, symbolTable, calls,
387 /*traverseNestedCGNodes=*/true);
388 }
389
390 /// Mark the given callgraph node for deletion.
markForDeletion__anon7742a5a60911::Inliner391 void markForDeletion(CallGraphNode *node) { deadNodes.insert(node); }
392
393 /// This method properly disposes of callables that became dead during
394 /// inlining. This should not be called while iterating over the SCCs.
eraseDeadCallables__anon7742a5a60911::Inliner395 void eraseDeadCallables() {
396 for (CallGraphNode *node : deadNodes)
397 node->getCallableRegion()->getParentOp()->erase();
398 }
399
400 /// The set of callables known to be dead.
401 SmallPtrSet<CallGraphNode *, 8> deadNodes;
402
403 /// The current set of call instructions to consider for inlining.
404 SmallVector<ResolvedCall, 8> calls;
405
406 /// The callgraph being operated on.
407 CallGraph &cg;
408
409 /// A symbol table to use when resolving call lookups.
410 SymbolTableCollection &symbolTable;
411 };
412 } // namespace
413
414 /// Returns true if the given call should be inlined.
shouldInline(ResolvedCall & resolvedCall)415 static bool shouldInline(ResolvedCall &resolvedCall) {
416 // Don't allow inlining terminator calls. We currently don't support this
417 // case.
418 if (resolvedCall.call->hasTrait<OpTrait::IsTerminator>())
419 return false;
420
421 // Don't allow inlining if the target is an ancestor of the call. This
422 // prevents inlining recursively.
423 if (resolvedCall.targetNode->getCallableRegion()->isAncestor(
424 resolvedCall.call->getParentRegion()))
425 return false;
426
427 // Otherwise, inline.
428 return true;
429 }
430
431 /// Attempt to inline calls within the given scc. This function returns
432 /// success if any calls were inlined, failure otherwise.
inlineCallsInSCC(Inliner & inliner,CGUseList & useList,CallGraphSCC & currentSCC)433 static LogicalResult inlineCallsInSCC(Inliner &inliner, CGUseList &useList,
434 CallGraphSCC ¤tSCC) {
435 CallGraph &cg = inliner.cg;
436 auto &calls = inliner.calls;
437
438 // A set of dead nodes to remove after inlining.
439 SmallVector<CallGraphNode *, 1> deadNodes;
440
441 // Collect all of the direct calls within the nodes of the current SCC. We
442 // don't traverse nested callgraph nodes, because they are handled separately
443 // likely within a different SCC.
444 for (CallGraphNode *node : currentSCC) {
445 if (node->isExternal())
446 continue;
447
448 // Don't collect calls if the node is already dead.
449 if (useList.isDead(node)) {
450 deadNodes.push_back(node);
451 } else {
452 collectCallOps(*node->getCallableRegion(), node, cg, inliner.symbolTable,
453 calls, /*traverseNestedCGNodes=*/false);
454 }
455 }
456
457 // Try to inline each of the call operations. Don't cache the end iterator
458 // here as more calls may be added during inlining.
459 bool inlinedAnyCalls = false;
460 for (unsigned i = 0; i != calls.size(); ++i) {
461 ResolvedCall it = calls[i];
462 bool doInline = shouldInline(it);
463 CallOpInterface call = it.call;
464 LLVM_DEBUG({
465 if (doInline)
466 llvm::dbgs() << "* Inlining call: " << call << "\n";
467 else
468 llvm::dbgs() << "* Not inlining call: " << call << "\n";
469 });
470 if (!doInline)
471 continue;
472 Region *targetRegion = it.targetNode->getCallableRegion();
473
474 // If this is the last call to the target node and the node is discardable,
475 // then inline it in-place and delete the node if successful.
476 bool inlineInPlace = useList.hasOneUseAndDiscardable(it.targetNode);
477
478 LogicalResult inlineResult = inlineCall(
479 inliner, call, cast<CallableOpInterface>(targetRegion->getParentOp()),
480 targetRegion, /*shouldCloneInlinedRegion=*/!inlineInPlace);
481 if (failed(inlineResult)) {
482 LLVM_DEBUG(llvm::dbgs() << "** Failed to inline\n");
483 continue;
484 }
485 inlinedAnyCalls = true;
486
487 // If the inlining was successful, Merge the new uses into the source node.
488 useList.dropCallUses(it.sourceNode, call.getOperation(), cg);
489 useList.mergeUsesAfterInlining(it.targetNode, it.sourceNode);
490
491 // then erase the call.
492 call.erase();
493
494 // If we inlined in place, mark the node for deletion.
495 if (inlineInPlace) {
496 useList.eraseNode(it.targetNode);
497 deadNodes.push_back(it.targetNode);
498 }
499 }
500
501 for (CallGraphNode *node : deadNodes) {
502 currentSCC.remove(node);
503 inliner.markForDeletion(node);
504 }
505 calls.clear();
506 return success(inlinedAnyCalls);
507 }
508
509 //===----------------------------------------------------------------------===//
510 // InlinerPass
511 //===----------------------------------------------------------------------===//
512
513 namespace {
514 class InlinerPass : public InlinerBase<InlinerPass> {
515 public:
516 InlinerPass();
517 InlinerPass(const InlinerPass &) = default;
518 InlinerPass(std::function<void(OpPassManager &)> defaultPipeline);
519 InlinerPass(std::function<void(OpPassManager &)> defaultPipeline,
520 llvm::StringMap<OpPassManager> opPipelines);
521 void runOnOperation() override;
522
523 private:
524 /// Attempt to inline calls within the given scc, and run simplifications,
525 /// until a fixed point is reached. This allows for the inlining of newly
526 /// devirtualized calls. Returns failure if there was a fatal error during
527 /// inlining.
528 LogicalResult inlineSCC(Inliner &inliner, CGUseList &useList,
529 CallGraphSCC ¤tSCC, MLIRContext *context);
530
531 /// Optimize the nodes within the given SCC with one of the held optimization
532 /// pass pipelines. Returns failure if an error occurred during the
533 /// optimization of the SCC, success otherwise.
534 LogicalResult optimizeSCC(CallGraph &cg, CGUseList &useList,
535 CallGraphSCC ¤tSCC, MLIRContext *context);
536
537 /// Optimize the nodes within the given SCC in parallel. Returns failure if an
538 /// error occurred during the optimization of the SCC, success otherwise.
539 LogicalResult optimizeSCCAsync(MutableArrayRef<CallGraphNode *> nodesToVisit,
540 MLIRContext *context);
541
542 /// Optimize the given callable node with one of the pass managers provided
543 /// with `pipelines`, or the default pipeline. Returns failure if an error
544 /// occurred during the optimization of the callable, success otherwise.
545 LogicalResult optimizeCallable(CallGraphNode *node,
546 llvm::StringMap<OpPassManager> &pipelines);
547
548 /// Attempt to initialize the options of this pass from the given string.
549 /// Derived classes may override this method to hook into the point at which
550 /// options are initialized, but should generally always invoke this base
551 /// class variant.
552 LogicalResult initializeOptions(StringRef options) override;
553
554 /// An optional function that constructs a default optimization pipeline for
555 /// a given operation.
556 std::function<void(OpPassManager &)> defaultPipeline;
557 /// A map of operation names to pass pipelines to use when optimizing
558 /// callable operations of these types. This provides a specialized pipeline
559 /// instead of the default. The vector size is the number of threads used
560 /// during optimization.
561 SmallVector<llvm::StringMap<OpPassManager>, 8> opPipelines;
562 };
563 } // end anonymous namespace
564
InlinerPass()565 InlinerPass::InlinerPass() : InlinerPass(defaultInlinerOptPipeline) {}
InlinerPass(std::function<void (OpPassManager &)> defaultPipeline)566 InlinerPass::InlinerPass(std::function<void(OpPassManager &)> defaultPipeline)
567 : defaultPipeline(defaultPipeline) {
568 opPipelines.push_back({});
569
570 // Initialize the pass options with the provided arguments.
571 if (defaultPipeline) {
572 OpPassManager fakePM("__mlir_fake_pm_op");
573 defaultPipeline(fakePM);
574 llvm::raw_string_ostream strStream(defaultPipelineStr);
575 fakePM.printAsTextualPipeline(strStream);
576 }
577 }
578
InlinerPass(std::function<void (OpPassManager &)> defaultPipeline,llvm::StringMap<OpPassManager> opPipelines)579 InlinerPass::InlinerPass(std::function<void(OpPassManager &)> defaultPipeline,
580 llvm::StringMap<OpPassManager> opPipelines)
581 : InlinerPass(std::move(defaultPipeline)) {
582 if (opPipelines.empty())
583 return;
584
585 // Update the option for the op specific optimization pipelines.
586 for (auto &it : opPipelines) {
587 std::string pipeline;
588 llvm::raw_string_ostream pipelineOS(pipeline);
589 pipelineOS << it.getKey() << "(";
590 it.second.printAsTextualPipeline(pipelineOS);
591 pipelineOS << ")";
592 opPipelineStrs.addValue(pipeline);
593 }
594 this->opPipelines.emplace_back(std::move(opPipelines));
595 }
596
runOnOperation()597 void InlinerPass::runOnOperation() {
598 CallGraph &cg = getAnalysis<CallGraph>();
599 auto *context = &getContext();
600
601 // The inliner should only be run on operations that define a symbol table,
602 // as the callgraph will need to resolve references.
603 Operation *op = getOperation();
604 if (!op->hasTrait<OpTrait::SymbolTable>()) {
605 op->emitOpError() << " was scheduled to run under the inliner, but does "
606 "not define a symbol table";
607 return signalPassFailure();
608 }
609
610 // Run the inline transform in post-order over the SCCs in the callgraph.
611 SymbolTableCollection symbolTable;
612 Inliner inliner(context, cg, symbolTable);
613 CGUseList useList(getOperation(), cg, symbolTable);
614 LogicalResult result = runTransformOnCGSCCs(cg, [&](CallGraphSCC &scc) {
615 return inlineSCC(inliner, useList, scc, context);
616 });
617 if (failed(result))
618 return signalPassFailure();
619
620 // After inlining, make sure to erase any callables proven to be dead.
621 inliner.eraseDeadCallables();
622 }
623
inlineSCC(Inliner & inliner,CGUseList & useList,CallGraphSCC & currentSCC,MLIRContext * context)624 LogicalResult InlinerPass::inlineSCC(Inliner &inliner, CGUseList &useList,
625 CallGraphSCC ¤tSCC,
626 MLIRContext *context) {
627 // Continuously simplify and inline until we either reach a fixed point, or
628 // hit the maximum iteration count. Simplifying early helps to refine the cost
629 // model, and in future iterations may devirtualize new calls.
630 unsigned iterationCount = 0;
631 do {
632 if (failed(optimizeSCC(inliner.cg, useList, currentSCC, context)))
633 return failure();
634 if (failed(inlineCallsInSCC(inliner, useList, currentSCC)))
635 break;
636 } while (++iterationCount < maxInliningIterations);
637 return success();
638 }
639
optimizeSCC(CallGraph & cg,CGUseList & useList,CallGraphSCC & currentSCC,MLIRContext * context)640 LogicalResult InlinerPass::optimizeSCC(CallGraph &cg, CGUseList &useList,
641 CallGraphSCC ¤tSCC,
642 MLIRContext *context) {
643 // Collect the sets of nodes to simplify.
644 SmallVector<CallGraphNode *, 4> nodesToVisit;
645 for (auto *node : currentSCC) {
646 if (node->isExternal())
647 continue;
648
649 // Don't simplify nodes with children. Nodes with children require special
650 // handling as we may remove the node during simplification. In the future,
651 // we should be able to handle this case with proper node deletion tracking.
652 if (node->hasChildren())
653 continue;
654
655 // We also won't apply simplifications to nodes that can't have passes
656 // scheduled on them.
657 auto *region = node->getCallableRegion();
658 if (!region->getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>())
659 continue;
660 nodesToVisit.push_back(node);
661 }
662 if (nodesToVisit.empty())
663 return success();
664
665 // Optimize each of the nodes within the SCC in parallel.
666 if (failed(optimizeSCCAsync(nodesToVisit, context)))
667 return failure();
668
669 // Recompute the uses held by each of the nodes.
670 for (CallGraphNode *node : nodesToVisit)
671 useList.recomputeUses(node, cg);
672 return success();
673 }
674
675 LogicalResult
optimizeSCCAsync(MutableArrayRef<CallGraphNode * > nodesToVisit,MLIRContext * ctx)676 InlinerPass::optimizeSCCAsync(MutableArrayRef<CallGraphNode *> nodesToVisit,
677 MLIRContext *ctx) {
678 // Ensure that there are enough pipeline maps for the optimizer to run in
679 // parallel. Note: The number of pass managers here needs to remain constant
680 // to prevent issues with pass instrumentations that rely on having the same
681 // pass manager for the main thread.
682 size_t numThreads = llvm::hardware_concurrency().compute_thread_count();
683 if (opPipelines.size() < numThreads) {
684 // Reserve before resizing so that we can use a reference to the first
685 // element.
686 opPipelines.reserve(numThreads);
687 opPipelines.resize(numThreads, opPipelines.front());
688 }
689
690 // Ensure an analysis manager has been constructed for each of the nodes.
691 // This prevents thread races when running the nested pipelines.
692 for (CallGraphNode *node : nodesToVisit)
693 getAnalysisManager().nest(node->getCallableRegion()->getParentOp());
694
695 // An atomic failure variable for the async executors.
696 std::vector<std::atomic<bool>> activePMs(opPipelines.size());
697 std::fill(activePMs.begin(), activePMs.end(), false);
698 return failableParallelForEach(ctx, nodesToVisit, [&](CallGraphNode *node) {
699 // Find a pass manager for this operation.
700 auto it = llvm::find_if(activePMs, [](std::atomic<bool> &isActive) {
701 bool expectedInactive = false;
702 return isActive.compare_exchange_strong(expectedInactive, true);
703 });
704 unsigned pmIndex = it - activePMs.begin();
705
706 // Optimize this callable node.
707 LogicalResult result = optimizeCallable(node, opPipelines[pmIndex]);
708
709 // Reset the active bit for this pass manager.
710 activePMs[pmIndex].store(false);
711 return result;
712 });
713 }
714
715 LogicalResult
optimizeCallable(CallGraphNode * node,llvm::StringMap<OpPassManager> & pipelines)716 InlinerPass::optimizeCallable(CallGraphNode *node,
717 llvm::StringMap<OpPassManager> &pipelines) {
718 Operation *callable = node->getCallableRegion()->getParentOp();
719 StringRef opName = callable->getName().getStringRef();
720 auto pipelineIt = pipelines.find(opName);
721 if (pipelineIt == pipelines.end()) {
722 // If a pipeline didn't exist, use the default if possible.
723 if (!defaultPipeline)
724 return success();
725
726 OpPassManager defaultPM(opName);
727 defaultPipeline(defaultPM);
728 pipelineIt = pipelines.try_emplace(opName, std::move(defaultPM)).first;
729 }
730 return runPipeline(pipelineIt->second, callable);
731 }
732
initializeOptions(StringRef options)733 LogicalResult InlinerPass::initializeOptions(StringRef options) {
734 if (failed(Pass::initializeOptions(options)))
735 return failure();
736
737 // Initialize the default pipeline builder to use the option string.
738 if (!defaultPipelineStr.empty()) {
739 std::string defaultPipelineCopy = defaultPipelineStr;
740 defaultPipeline = [=](OpPassManager &pm) {
741 (void)parsePassPipeline(defaultPipelineCopy, pm);
742 };
743 } else if (defaultPipelineStr.getNumOccurrences()) {
744 defaultPipeline = nullptr;
745 }
746
747 // Initialize the op specific pass pipelines.
748 llvm::StringMap<OpPassManager> pipelines;
749 for (StringRef pipeline : opPipelineStrs) {
750 // Skip empty pipelines.
751 if (pipeline.empty())
752 continue;
753
754 // Pipelines are expected to be of the form `<op-name>(<pipeline>)`.
755 size_t pipelineStart = pipeline.find_first_of('(');
756 if (pipelineStart == StringRef::npos || !pipeline.consume_back(")"))
757 return failure();
758 StringRef opName = pipeline.take_front(pipelineStart);
759 OpPassManager pm(opName);
760 if (failed(parsePassPipeline(pipeline.drop_front(1 + pipelineStart), pm)))
761 return failure();
762 pipelines.try_emplace(opName, std::move(pm));
763 }
764 opPipelines.assign({std::move(pipelines)});
765
766 return success();
767 }
768
createInlinerPass()769 std::unique_ptr<Pass> mlir::createInlinerPass() {
770 return std::make_unique<InlinerPass>();
771 }
772 std::unique_ptr<Pass>
createInlinerPass(llvm::StringMap<OpPassManager> opPipelines)773 mlir::createInlinerPass(llvm::StringMap<OpPassManager> opPipelines) {
774 return std::make_unique<InlinerPass>(defaultInlinerOptPipeline,
775 std::move(opPipelines));
776 }
777 std::unique_ptr<Pass>
createInlinerPass(llvm::StringMap<OpPassManager> opPipelines,std::function<void (OpPassManager &)> defaultPipelineBuilder)778 createInlinerPass(llvm::StringMap<OpPassManager> opPipelines,
779 std::function<void(OpPassManager &)> defaultPipelineBuilder) {
780 return std::make_unique<InlinerPass>(std::move(defaultPipelineBuilder),
781 std::move(opPipelines));
782 }
783