1 //===- OptUtils.cpp - MLIR Execution Engine optimization pass utilities ---===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the utility functions to trigger LLVM optimizations from
10 // MLIR Execution Engine.
11 //
12 //===----------------------------------------------------------------------===//
13
14 #include "mlir/ExecutionEngine/OptUtils.h"
15
16 #include "llvm/ADT/ArrayRef.h"
17 #include "llvm/Analysis/TargetTransformInfo.h"
18 #include "llvm/IR/LegacyPassManager.h"
19 #include "llvm/IR/LegacyPassNameParser.h"
20 #include "llvm/IR/Module.h"
21 #include "llvm/InitializePasses.h"
22 #include "llvm/Pass.h"
23 #include "llvm/Support/Allocator.h"
24 #include "llvm/Support/CommandLine.h"
25 #include "llvm/Support/Error.h"
26 #include "llvm/Support/StringSaver.h"
27 #include "llvm/Target/TargetMachine.h"
28 #include "llvm/Transforms/IPO.h"
29 #include "llvm/Transforms/IPO/PassManagerBuilder.h"
30 #include <climits>
31 #include <mutex>
32
33 // Run the module and function passes managed by the module manager.
runPasses(llvm::legacy::PassManager & modulePM,llvm::legacy::FunctionPassManager & funcPM,llvm::Module & m)34 static void runPasses(llvm::legacy::PassManager &modulePM,
35 llvm::legacy::FunctionPassManager &funcPM,
36 llvm::Module &m) {
37 funcPM.doInitialization();
38 for (auto &func : m) {
39 funcPM.run(func);
40 }
41 funcPM.doFinalization();
42 modulePM.run(m);
43 }
44
45 // Initialize basic LLVM transformation passes under lock.
initializeLLVMPasses()46 void mlir::initializeLLVMPasses() {
47 static std::mutex mutex;
48 std::lock_guard<std::mutex> lock(mutex);
49
50 auto ®istry = *llvm::PassRegistry::getPassRegistry();
51 llvm::initializeCore(registry);
52 llvm::initializeTransformUtils(registry);
53 llvm::initializeScalarOpts(registry);
54 llvm::initializeIPO(registry);
55 llvm::initializeInstCombine(registry);
56 llvm::initializeAggressiveInstCombine(registry);
57 llvm::initializeAnalysis(registry);
58 llvm::initializeVectorization(registry);
59 }
60
61 // Populate pass managers according to the optimization and size levels.
62 // This behaves similarly to LLVM opt.
populatePassManagers(llvm::legacy::PassManager & modulePM,llvm::legacy::FunctionPassManager & funcPM,unsigned optLevel,unsigned sizeLevel,llvm::TargetMachine * targetMachine)63 static void populatePassManagers(llvm::legacy::PassManager &modulePM,
64 llvm::legacy::FunctionPassManager &funcPM,
65 unsigned optLevel, unsigned sizeLevel,
66 llvm::TargetMachine *targetMachine) {
67 llvm::PassManagerBuilder builder;
68 builder.OptLevel = optLevel;
69 builder.SizeLevel = sizeLevel;
70 builder.Inliner = llvm::createFunctionInliningPass(
71 optLevel, sizeLevel, /*DisableInlineHotCallSite=*/false);
72 builder.LoopVectorize = optLevel > 1 && sizeLevel < 2;
73 builder.SLPVectorize = optLevel > 1 && sizeLevel < 2;
74 builder.DisableUnrollLoops = (optLevel == 0);
75
76 if (targetMachine) {
77 // Add pass to initialize TTI for this specific target. Otherwise, TTI will
78 // be initialized to NoTTIImpl by default.
79 modulePM.add(createTargetTransformInfoWrapperPass(
80 targetMachine->getTargetIRAnalysis()));
81 funcPM.add(createTargetTransformInfoWrapperPass(
82 targetMachine->getTargetIRAnalysis()));
83 }
84
85 builder.populateModulePassManager(modulePM);
86 builder.populateFunctionPassManager(funcPM);
87 }
88
89 // Create and return a lambda that uses LLVM pass manager builder to set up
90 // optimizations based on the given level.
91 std::function<llvm::Error(llvm::Module *)>
makeOptimizingTransformer(unsigned optLevel,unsigned sizeLevel,llvm::TargetMachine * targetMachine)92 mlir::makeOptimizingTransformer(unsigned optLevel, unsigned sizeLevel,
93 llvm::TargetMachine *targetMachine) {
94 return [optLevel, sizeLevel, targetMachine](llvm::Module *m) -> llvm::Error {
95 llvm::legacy::PassManager modulePM;
96 llvm::legacy::FunctionPassManager funcPM(m);
97 populatePassManagers(modulePM, funcPM, optLevel, sizeLevel, targetMachine);
98 runPasses(modulePM, funcPM, *m);
99
100 return llvm::Error::success();
101 };
102 }
103
104 // Create and return a lambda that is given a set of passes to run, plus an
105 // optional optimization level to pre-populate the pass manager.
makeLLVMPassesTransformer(llvm::ArrayRef<const llvm::PassInfo * > llvmPasses,llvm::Optional<unsigned> mbOptLevel,llvm::TargetMachine * targetMachine,unsigned optPassesInsertPos)106 std::function<llvm::Error(llvm::Module *)> mlir::makeLLVMPassesTransformer(
107 llvm::ArrayRef<const llvm::PassInfo *> llvmPasses,
108 llvm::Optional<unsigned> mbOptLevel, llvm::TargetMachine *targetMachine,
109 unsigned optPassesInsertPos) {
110 return [llvmPasses, mbOptLevel, optPassesInsertPos,
111 targetMachine](llvm::Module *m) -> llvm::Error {
112 llvm::legacy::PassManager modulePM;
113 llvm::legacy::FunctionPassManager funcPM(m);
114
115 bool insertOptPasses = mbOptLevel.hasValue();
116 for (unsigned i = 0, e = llvmPasses.size(); i < e; ++i) {
117 const auto *passInfo = llvmPasses[i];
118 if (!passInfo->getNormalCtor())
119 continue;
120
121 if (insertOptPasses && optPassesInsertPos == i) {
122 populatePassManagers(modulePM, funcPM, mbOptLevel.getValue(), 0,
123 targetMachine);
124 insertOptPasses = false;
125 }
126
127 auto *pass = passInfo->createPass();
128 if (!pass)
129 return llvm::make_error<llvm::StringError>(
130 "could not create pass " + passInfo->getPassName(),
131 llvm::inconvertibleErrorCode());
132 modulePM.add(pass);
133 }
134
135 if (insertOptPasses)
136 populatePassManagers(modulePM, funcPM, mbOptLevel.getValue(), 0,
137 targetMachine);
138
139 runPasses(modulePM, funcPM, *m);
140 return llvm::Error::success();
141 };
142 }
143