1 //===- OptUtils.cpp - MLIR Execution Engine optimization pass utilities ---===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the utility functions to trigger LLVM optimizations from
10 // MLIR Execution Engine.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/ExecutionEngine/OptUtils.h"
15 
16 #include "llvm/ADT/ArrayRef.h"
17 #include "llvm/Analysis/TargetTransformInfo.h"
18 #include "llvm/IR/LegacyPassManager.h"
19 #include "llvm/IR/LegacyPassNameParser.h"
20 #include "llvm/IR/Module.h"
21 #include "llvm/InitializePasses.h"
22 #include "llvm/Pass.h"
23 #include "llvm/Support/Allocator.h"
24 #include "llvm/Support/CommandLine.h"
25 #include "llvm/Support/Error.h"
26 #include "llvm/Support/StringSaver.h"
27 #include "llvm/Target/TargetMachine.h"
28 #include "llvm/Transforms/IPO.h"
29 #include "llvm/Transforms/IPO/PassManagerBuilder.h"
#include <climits>
#include <mutex>
#include <vector>
32 
33 // Run the module and function passes managed by the module manager.
runPasses(llvm::legacy::PassManager & modulePM,llvm::legacy::FunctionPassManager & funcPM,llvm::Module & m)34 static void runPasses(llvm::legacy::PassManager &modulePM,
35                       llvm::legacy::FunctionPassManager &funcPM,
36                       llvm::Module &m) {
37   funcPM.doInitialization();
38   for (auto &func : m) {
39     funcPM.run(func);
40   }
41   funcPM.doFinalization();
42   modulePM.run(m);
43 }
44 
45 // Initialize basic LLVM transformation passes under lock.
initializeLLVMPasses()46 void mlir::initializeLLVMPasses() {
47   static std::mutex mutex;
48   std::lock_guard<std::mutex> lock(mutex);
49 
50   auto &registry = *llvm::PassRegistry::getPassRegistry();
51   llvm::initializeCore(registry);
52   llvm::initializeTransformUtils(registry);
53   llvm::initializeScalarOpts(registry);
54   llvm::initializeIPO(registry);
55   llvm::initializeInstCombine(registry);
56   llvm::initializeAggressiveInstCombine(registry);
57   llvm::initializeAnalysis(registry);
58   llvm::initializeVectorization(registry);
59 }
60 
61 // Populate pass managers according to the optimization and size levels.
62 // This behaves similarly to LLVM opt.
populatePassManagers(llvm::legacy::PassManager & modulePM,llvm::legacy::FunctionPassManager & funcPM,unsigned optLevel,unsigned sizeLevel,llvm::TargetMachine * targetMachine)63 static void populatePassManagers(llvm::legacy::PassManager &modulePM,
64                                  llvm::legacy::FunctionPassManager &funcPM,
65                                  unsigned optLevel, unsigned sizeLevel,
66                                  llvm::TargetMachine *targetMachine) {
67   llvm::PassManagerBuilder builder;
68   builder.OptLevel = optLevel;
69   builder.SizeLevel = sizeLevel;
70   builder.Inliner = llvm::createFunctionInliningPass(
71       optLevel, sizeLevel, /*DisableInlineHotCallSite=*/false);
72   builder.LoopVectorize = optLevel > 1 && sizeLevel < 2;
73   builder.SLPVectorize = optLevel > 1 && sizeLevel < 2;
74   builder.DisableUnrollLoops = (optLevel == 0);
75 
76   if (targetMachine) {
77     // Add pass to initialize TTI for this specific target. Otherwise, TTI will
78     // be initialized to NoTTIImpl by default.
79     modulePM.add(createTargetTransformInfoWrapperPass(
80         targetMachine->getTargetIRAnalysis()));
81     funcPM.add(createTargetTransformInfoWrapperPass(
82         targetMachine->getTargetIRAnalysis()));
83   }
84 
85   builder.populateModulePassManager(modulePM);
86   builder.populateFunctionPassManager(funcPM);
87 }
88 
89 // Create and return a lambda that uses LLVM pass manager builder to set up
90 // optimizations based on the given level.
91 std::function<llvm::Error(llvm::Module *)>
makeOptimizingTransformer(unsigned optLevel,unsigned sizeLevel,llvm::TargetMachine * targetMachine)92 mlir::makeOptimizingTransformer(unsigned optLevel, unsigned sizeLevel,
93                                 llvm::TargetMachine *targetMachine) {
94   return [optLevel, sizeLevel, targetMachine](llvm::Module *m) -> llvm::Error {
95     llvm::legacy::PassManager modulePM;
96     llvm::legacy::FunctionPassManager funcPM(m);
97     populatePassManagers(modulePM, funcPM, optLevel, sizeLevel, targetMachine);
98     runPasses(modulePM, funcPM, *m);
99 
100     return llvm::Error::success();
101   };
102 }
103 
104 // Create and return a lambda that is given a set of passes to run, plus an
105 // optional optimization level to pre-populate the pass manager.
makeLLVMPassesTransformer(llvm::ArrayRef<const llvm::PassInfo * > llvmPasses,llvm::Optional<unsigned> mbOptLevel,llvm::TargetMachine * targetMachine,unsigned optPassesInsertPos)106 std::function<llvm::Error(llvm::Module *)> mlir::makeLLVMPassesTransformer(
107     llvm::ArrayRef<const llvm::PassInfo *> llvmPasses,
108     llvm::Optional<unsigned> mbOptLevel, llvm::TargetMachine *targetMachine,
109     unsigned optPassesInsertPos) {
110   return [llvmPasses, mbOptLevel, optPassesInsertPos,
111           targetMachine](llvm::Module *m) -> llvm::Error {
112     llvm::legacy::PassManager modulePM;
113     llvm::legacy::FunctionPassManager funcPM(m);
114 
115     bool insertOptPasses = mbOptLevel.hasValue();
116     for (unsigned i = 0, e = llvmPasses.size(); i < e; ++i) {
117       const auto *passInfo = llvmPasses[i];
118       if (!passInfo->getNormalCtor())
119         continue;
120 
121       if (insertOptPasses && optPassesInsertPos == i) {
122         populatePassManagers(modulePM, funcPM, mbOptLevel.getValue(), 0,
123                              targetMachine);
124         insertOptPasses = false;
125       }
126 
127       auto *pass = passInfo->createPass();
128       if (!pass)
129         return llvm::make_error<llvm::StringError>(
130             "could not create pass " + passInfo->getPassName(),
131             llvm::inconvertibleErrorCode());
132       modulePM.add(pass);
133     }
134 
135     if (insertOptPasses)
136       populatePassManagers(modulePM, funcPM, mbOptLevel.getValue(), 0,
137                            targetMachine);
138 
139     runPasses(modulePM, funcPM, *m);
140     return llvm::Error::success();
141   };
142 }
143