1 //===- llvm-extract.cpp - LLVM function extraction utility ----------------===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 //
// This utility changes the input module to contain only the requested
// functions, global variables, aliases and basic blocks (or, with -delete, to
// remove them instead); it is primarily used for debugging transformations.
12 //
13 //===----------------------------------------------------------------------===//
14 
15 #include "llvm/ADT/SetVector.h"
16 #include "llvm/ADT/SmallPtrSet.h"
17 #include "llvm/Bitcode/BitcodeWriterPass.h"
18 #include "llvm/IR/DataLayout.h"
19 #include "llvm/IR/IRPrintingPasses.h"
20 #include "llvm/IR/Instructions.h"
21 #include "llvm/IR/LLVMContext.h"
22 #include "llvm/IR/LegacyPassManager.h"
23 #include "llvm/IR/Module.h"
24 #include "llvm/IRReader/IRReader.h"
25 #include "llvm/Support/CommandLine.h"
26 #include "llvm/Support/Error.h"
27 #include "llvm/Support/FileSystem.h"
28 #include "llvm/Support/InitLLVM.h"
29 #include "llvm/Support/Regex.h"
30 #include "llvm/Support/SourceMgr.h"
31 #include "llvm/Support/SystemUtils.h"
32 #include "llvm/Support/ToolOutputFile.h"
33 #include "llvm/Transforms/IPO.h"
34 #include <memory>
35 using namespace llvm;
36 
37 // InputFilename - The filename to read from.
38 static cl::opt<std::string>
39 InputFilename(cl::Positional, cl::desc("<input bitcode file>"),
40               cl::init("-"), cl::value_desc("filename"));
41 
42 static cl::opt<std::string>
43 OutputFilename("o", cl::desc("Specify output filename"),
44                cl::value_desc("filename"), cl::init("-"));
45 
46 static cl::opt<bool>
47 Force("f", cl::desc("Enable binary output on terminals"));
48 
49 static cl::opt<bool>
50 DeleteFn("delete", cl::desc("Delete specified Globals from Module"));
51 
52 static cl::opt<bool>
53     Recursive("recursive",
54               cl::desc("Recursively extract all called functions"));
55 
56 // ExtractFuncs - The functions to extract from the module.
57 static cl::list<std::string>
58 ExtractFuncs("func", cl::desc("Specify function to extract"),
59              cl::ZeroOrMore, cl::value_desc("function"));
60 
61 // ExtractRegExpFuncs - The functions, matched via regular expression, to
62 // extract from the module.
63 static cl::list<std::string>
64 ExtractRegExpFuncs("rfunc", cl::desc("Specify function(s) to extract using a "
65                                      "regular expression"),
66                    cl::ZeroOrMore, cl::value_desc("rfunction"));
67 
68 // ExtractBlocks - The blocks to extract from the module.
69 static cl::list<std::string>
70     ExtractBlocks("bb",
71                   cl::desc("Specify <function, basic block> pairs to extract"),
72                   cl::ZeroOrMore, cl::value_desc("function:bb"));
73 
74 // ExtractAlias - The alias to extract from the module.
75 static cl::list<std::string>
76 ExtractAliases("alias", cl::desc("Specify alias to extract"),
77                cl::ZeroOrMore, cl::value_desc("alias"));
78 
79 
80 // ExtractRegExpAliases - The aliases, matched via regular expression, to
81 // extract from the module.
82 static cl::list<std::string>
83 ExtractRegExpAliases("ralias", cl::desc("Specify alias(es) to extract using a "
84                                         "regular expression"),
85                      cl::ZeroOrMore, cl::value_desc("ralias"));
86 
87 // ExtractGlobals - The globals to extract from the module.
88 static cl::list<std::string>
89 ExtractGlobals("glob", cl::desc("Specify global to extract"),
90                cl::ZeroOrMore, cl::value_desc("global"));
91 
92 // ExtractRegExpGlobals - The globals, matched via regular expression, to
93 // extract from the module...
94 static cl::list<std::string>
95 ExtractRegExpGlobals("rglob", cl::desc("Specify global(s) to extract using a "
96                                        "regular expression"),
97                      cl::ZeroOrMore, cl::value_desc("rglobal"));
98 
99 static cl::opt<bool>
100 OutputAssembly("S",
101                cl::desc("Write output as LLVM assembly"), cl::Hidden);
102 
103 static cl::opt<bool> PreserveBitcodeUseListOrder(
104     "preserve-bc-uselistorder",
105     cl::desc("Preserve use-list order when writing LLVM bitcode."),
106     cl::init(true), cl::Hidden);
107 
108 static cl::opt<bool> PreserveAssemblyUseListOrder(
109     "preserve-ll-uselistorder",
110     cl::desc("Preserve use-list order when writing LLVM assembly."),
111     cl::init(false), cl::Hidden);
112 
main(int argc,char ** argv)113 int main(int argc, char **argv) {
114   InitLLVM X(argc, argv);
115 
116   LLVMContext Context;
117   cl::ParseCommandLineOptions(argc, argv, "llvm extractor\n");
118 
119   // Use lazy loading, since we only care about selected global values.
120   SMDiagnostic Err;
121   std::unique_ptr<Module> M = getLazyIRFileModule(InputFilename, Err, Context);
122 
123   if (!M.get()) {
124     Err.print(argv[0], errs());
125     return 1;
126   }
127 
128   // Use SetVector to avoid duplicates.
129   SetVector<GlobalValue *> GVs;
130 
131   // Figure out which aliases we should extract.
132   for (size_t i = 0, e = ExtractAliases.size(); i != e; ++i) {
133     GlobalAlias *GA = M->getNamedAlias(ExtractAliases[i]);
134     if (!GA) {
135       errs() << argv[0] << ": program doesn't contain alias named '"
136              << ExtractAliases[i] << "'!\n";
137       return 1;
138     }
139     GVs.insert(GA);
140   }
141 
142   // Extract aliases via regular expression matching.
143   for (size_t i = 0, e = ExtractRegExpAliases.size(); i != e; ++i) {
144     std::string Error;
145     Regex RegEx(ExtractRegExpAliases[i]);
146     if (!RegEx.isValid(Error)) {
147       errs() << argv[0] << ": '" << ExtractRegExpAliases[i] << "' "
148         "invalid regex: " << Error;
149     }
150     bool match = false;
151     for (Module::alias_iterator GA = M->alias_begin(), E = M->alias_end();
152          GA != E; GA++) {
153       if (RegEx.match(GA->getName())) {
154         GVs.insert(&*GA);
155         match = true;
156       }
157     }
158     if (!match) {
159       errs() << argv[0] << ": program doesn't contain global named '"
160              << ExtractRegExpAliases[i] << "'!\n";
161       return 1;
162     }
163   }
164 
165   // Figure out which globals we should extract.
166   for (size_t i = 0, e = ExtractGlobals.size(); i != e; ++i) {
167     GlobalValue *GV = M->getNamedGlobal(ExtractGlobals[i]);
168     if (!GV) {
169       errs() << argv[0] << ": program doesn't contain global named '"
170              << ExtractGlobals[i] << "'!\n";
171       return 1;
172     }
173     GVs.insert(GV);
174   }
175 
176   // Extract globals via regular expression matching.
177   for (size_t i = 0, e = ExtractRegExpGlobals.size(); i != e; ++i) {
178     std::string Error;
179     Regex RegEx(ExtractRegExpGlobals[i]);
180     if (!RegEx.isValid(Error)) {
181       errs() << argv[0] << ": '" << ExtractRegExpGlobals[i] << "' "
182         "invalid regex: " << Error;
183     }
184     bool match = false;
185     for (auto &GV : M->globals()) {
186       if (RegEx.match(GV.getName())) {
187         GVs.insert(&GV);
188         match = true;
189       }
190     }
191     if (!match) {
192       errs() << argv[0] << ": program doesn't contain global named '"
193              << ExtractRegExpGlobals[i] << "'!\n";
194       return 1;
195     }
196   }
197 
198   // Figure out which functions we should extract.
199   for (size_t i = 0, e = ExtractFuncs.size(); i != e; ++i) {
200     GlobalValue *GV = M->getFunction(ExtractFuncs[i]);
201     if (!GV) {
202       errs() << argv[0] << ": program doesn't contain function named '"
203              << ExtractFuncs[i] << "'!\n";
204       return 1;
205     }
206     GVs.insert(GV);
207   }
208   // Extract functions via regular expression matching.
209   for (size_t i = 0, e = ExtractRegExpFuncs.size(); i != e; ++i) {
210     std::string Error;
211     StringRef RegExStr = ExtractRegExpFuncs[i];
212     Regex RegEx(RegExStr);
213     if (!RegEx.isValid(Error)) {
214       errs() << argv[0] << ": '" << ExtractRegExpFuncs[i] << "' "
215         "invalid regex: " << Error;
216     }
217     bool match = false;
218     for (Module::iterator F = M->begin(), E = M->end(); F != E;
219          F++) {
220       if (RegEx.match(F->getName())) {
221         GVs.insert(&*F);
222         match = true;
223       }
224     }
225     if (!match) {
226       errs() << argv[0] << ": program doesn't contain global named '"
227              << ExtractRegExpFuncs[i] << "'!\n";
228       return 1;
229     }
230   }
231 
232   // Figure out which BasicBlocks we should extract.
233   SmallVector<BasicBlock *, 4> BBs;
234   for (StringRef StrPair : ExtractBlocks) {
235     auto BBInfo = StrPair.split(':');
236     // Get the function.
237     Function *F = M->getFunction(BBInfo.first);
238     if (!F) {
239       errs() << argv[0] << ": program doesn't contain a function named '"
240              << BBInfo.first << "'!\n";
241       return 1;
242     }
243     // Do not materialize this function.
244     GVs.insert(F);
245     // Get the basic block.
246     auto Res = llvm::find_if(*F, [&](const BasicBlock &BB) {
247       return BB.getName().equals(BBInfo.second);
248     });
249     if (Res == F->end()) {
250       errs() << argv[0] << ": function " << F->getName()
251              << " doesn't contain a basic block named '" << BBInfo.second
252              << "'!\n";
253       return 1;
254     }
255     BBs.push_back(&*Res);
256   }
257 
258   // Use *argv instead of argv[0] to work around a wrong GCC warning.
259   ExitOnError ExitOnErr(std::string(*argv) + ": error reading input: ");
260 
261   if (Recursive) {
262     std::vector<llvm::Function *> Workqueue;
263     for (GlobalValue *GV : GVs) {
264       if (auto *F = dyn_cast<Function>(GV)) {
265         Workqueue.push_back(F);
266       }
267     }
268     while (!Workqueue.empty()) {
269       Function *F = &*Workqueue.back();
270       Workqueue.pop_back();
271       ExitOnErr(F->materialize());
272       for (auto &BB : *F) {
273         for (auto &I : BB) {
274           auto *CI = dyn_cast<CallInst>(&I);
275           if (!CI)
276             continue;
277           Function *CF = CI->getCalledFunction();
278           if (!CF)
279             continue;
280           if (CF->isDeclaration() || GVs.count(CF))
281             continue;
282           GVs.insert(CF);
283           Workqueue.push_back(CF);
284         }
285       }
286     }
287   }
288 
289   auto Materialize = [&](GlobalValue &GV) { ExitOnErr(GV.materialize()); };
290 
291   // Materialize requisite global values.
292   if (!DeleteFn) {
293     for (size_t i = 0, e = GVs.size(); i != e; ++i)
294       Materialize(*GVs[i]);
295   } else {
296     // Deleting. Materialize every GV that's *not* in GVs.
297     SmallPtrSet<GlobalValue *, 8> GVSet(GVs.begin(), GVs.end());
298     for (auto &F : *M) {
299       if (!GVSet.count(&F))
300         Materialize(F);
301     }
302   }
303 
304   {
305     std::vector<GlobalValue *> Gvs(GVs.begin(), GVs.end());
306     legacy::PassManager Extract;
307     Extract.add(createGVExtractionPass(Gvs, DeleteFn));
308     Extract.run(*M);
309 
310     // Now that we have all the GVs we want, mark the module as fully
311     // materialized.
312     // FIXME: should the GVExtractionPass handle this?
313     ExitOnErr(M->materializeAll());
314   }
315 
316   // Extract the specified basic blocks from the module and erase the existing
317   // functions.
318   if (!ExtractBlocks.empty()) {
319     legacy::PassManager PM;
320     PM.add(createBlockExtractorPass(BBs, true));
321     PM.run(*M);
322   }
323 
324   // In addition to deleting all other functions, we also want to spiff it
325   // up a little bit.  Do this now.
326   legacy::PassManager Passes;
327 
328   if (!DeleteFn)
329     Passes.add(createGlobalDCEPass());           // Delete unreachable globals
330   Passes.add(createStripDeadDebugInfoPass());    // Remove dead debug info
331   Passes.add(createStripDeadPrototypesPass());   // Remove dead func decls
332 
333   std::error_code EC;
334   ToolOutputFile Out(OutputFilename, EC, sys::fs::F_None);
335   if (EC) {
336     errs() << EC.message() << '\n';
337     return 1;
338   }
339 
340   if (OutputAssembly)
341     Passes.add(
342         createPrintModulePass(Out.os(), "", PreserveAssemblyUseListOrder));
343   else if (Force || !CheckBitcodeOutputToConsole(Out.os(), true))
344     Passes.add(createBitcodeWriterPass(Out.os(), PreserveBitcodeUseListOrder));
345 
346   Passes.run(*M.get());
347 
348   // Declare success.
349   Out.keep();
350 
351   return 0;
352 }
353