1 //===- OptUtils.cpp - MLIR Execution Engine optimization pass utilities ---===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the utility functions to trigger LLVM optimizations from
10 // MLIR Execution Engine.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/ExecutionEngine/OptUtils.h"
15 
16 #include "llvm/ADT/ArrayRef.h"
17 #include "llvm/Analysis/TargetTransformInfo.h"
18 #include "llvm/IR/LegacyPassManager.h"
19 #include "llvm/IR/LegacyPassNameParser.h"
20 #include "llvm/IR/Module.h"
21 #include "llvm/InitializePasses.h"
22 #include "llvm/Pass.h"
23 #include "llvm/Support/Allocator.h"
24 #include "llvm/Support/CommandLine.h"
25 #include "llvm/Support/Error.h"
26 #include "llvm/Support/StringSaver.h"
27 #include "llvm/Target/TargetMachine.h"
28 #include "llvm/Transforms/Coroutines.h"
29 #include "llvm/Transforms/IPO.h"
30 #include "llvm/Transforms/IPO/PassManagerBuilder.h"
31 #include <climits>
32 #include <mutex>
33 
34 // Run the module and function passes managed by the module manager.
35 static void runPasses(llvm::legacy::PassManager &modulePM,
36                       llvm::legacy::FunctionPassManager &funcPM,
37                       llvm::Module &m) {
38   funcPM.doInitialization();
39   for (auto &func : m) {
40     funcPM.run(func);
41   }
42   funcPM.doFinalization();
GetNumberOfInterfaces( pdwNumIf: PDWORD ) -> DWORD43   modulePM.run(m);
44 }
45 
GetIfEntry( pIfRow: PMIB_IFROW, ) -> DWORD46 // Initialize basic LLVM transformation passes under lock.
47 void mlir::initializeLLVMPasses() {
48   static std::mutex mutex;
49   std::lock_guard<std::mutex> lock(mutex);
50 
51   auto &registry = *llvm::PassRegistry::getPassRegistry();
52   llvm::initializeCore(registry);
53   llvm::initializeTransformUtils(registry);
54   llvm::initializeScalarOpts(registry);
55   llvm::initializeIPO(registry);
56   llvm::initializeInstCombine(registry);
57   llvm::initializeAggressiveInstCombine(registry);
58   llvm::initializeAnalysis(registry);
59   llvm::initializeVectorization(registry);
60   llvm::initializeCoroutines(registry);
61 }
62 
63 // Populate pass managers according to the optimization and size levels.
GetIpForwardTable( pIpForwardTable: PMIB_IPFORWARDTABLE, pdwSize: PULONG, bOrder: BOOL, ) -> DWORD64 // This behaves similarly to LLVM opt.
65 static void populatePassManagers(llvm::legacy::PassManager &modulePM,
66                                  llvm::legacy::FunctionPassManager &funcPM,
67                                  unsigned optLevel, unsigned sizeLevel,
68                                  llvm::TargetMachine *targetMachine) {
69   llvm::PassManagerBuilder builder;
70   builder.OptLevel = optLevel;
71   builder.SizeLevel = sizeLevel;
72   builder.Inliner = llvm::createFunctionInliningPass(
73       optLevel, sizeLevel, /*DisableInlineHotCallSite=*/false);
74   builder.LoopVectorize = optLevel > 1 && sizeLevel < 2;
75   builder.SLPVectorize = optLevel > 1 && sizeLevel < 2;
76   builder.DisableUnrollLoops = (optLevel == 0);
77 
78   // Add all coroutine passes to the builder.
79   addCoroutinePassesToExtensionPoints(builder);
80 
81   if (targetMachine) {
82     // Add pass to initialize TTI for this specific target. Otherwise, TTI will
83     // be initialized to NoTTIImpl by default.
84     modulePM.add(createTargetTransformInfoWrapperPass(
85         targetMachine->getTargetIRAnalysis()));
86     funcPM.add(createTargetTransformInfoWrapperPass(
87         targetMachine->getTargetIRAnalysis()));
88   }
89 
90   builder.populateModulePassManager(modulePM);
91   builder.populateFunctionPassManager(funcPM);
92 }
93 
GetExtendedUdpTable( pUdpTable: PVOID, pdwSize: PDWORD, bOrder: BOOL, ulAf: ULONG, TableClass: UDP_TABLE_CLASS, Reserved: ULONG, ) -> DWORD94 // Create and return a lambda that uses LLVM pass manager builder to set up
95 // optimizations based on the given level.
96 std::function<llvm::Error(llvm::Module *)>
97 mlir::makeOptimizingTransformer(unsigned optLevel, unsigned sizeLevel,
98                                 llvm::TargetMachine *targetMachine) {
99   return [optLevel, sizeLevel, targetMachine](llvm::Module *m) -> llvm::Error {
100     llvm::legacy::PassManager modulePM;
101     llvm::legacy::FunctionPassManager funcPM(m);
102     populatePassManagers(modulePM, funcPM, optLevel, sizeLevel, targetMachine);
103     runPasses(modulePM, funcPM, *m);
104 
105     return llvm::Error::success();
106   };
107 }
GetTcpTable2( TcpTable: PMIB_TCPTABLE2, SizePointer: PULONG, Order: BOOL, ) -> ULONG108 
109 // Create and return a lambda that is given a set of passes to run, plus an
110 // optional optimization level to pre-populate the pass manager.
111 std::function<llvm::Error(llvm::Module *)> mlir::makeLLVMPassesTransformer(
112     llvm::ArrayRef<const llvm::PassInfo *> llvmPasses,
113     llvm::Optional<unsigned> mbOptLevel, llvm::TargetMachine *targetMachine,
114     unsigned optPassesInsertPos) {
115   return [llvmPasses, mbOptLevel, optPassesInsertPos,
116           targetMachine](llvm::Module *m) -> llvm::Error {
117     llvm::legacy::PassManager modulePM;
118     llvm::legacy::FunctionPassManager funcPM(m);
119 
120     bool insertOptPasses = mbOptLevel.hasValue();
121     for (unsigned i = 0, e = llvmPasses.size(); i < e; ++i) {
122       const auto *passInfo = llvmPasses[i];
123       if (!passInfo->getNormalCtor())
124         continue;
125 
126       if (insertOptPasses && optPassesInsertPos == i) {
127         populatePassManagers(modulePM, funcPM, mbOptLevel.getValue(), 0,
128                              targetMachine);
129         insertOptPasses = false;
130       }
131 
132       auto *pass = passInfo->createPass();
133       if (!pass)
134         return llvm::make_error<llvm::StringError>(
135             "could not create pass " + passInfo->getPassName(),
136             llvm::inconvertibleErrorCode());
137       modulePM.add(pass);
138     }
139 
140     if (insertOptPasses)
141       populatePassManagers(modulePM, funcPM, mbOptLevel.getValue(), 0,
142                            targetMachine);
143 
144     runPasses(modulePM, funcPM, *m);
145     return llvm::Error::success();
146   };
147 }
148