1 //===- SymbolRewriter.cpp - Symbol Rewriter -------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// SymbolRewriter is an LLVM pass which can rewrite symbols transparently
// within existing code.  It is implemented as a compiler pass and is
// configured via a YAML configuration file.
12 //
13 // The YAML configuration file format is as follows:
14 //
15 // RewriteMapFile := RewriteDescriptors
// RewriteDescriptors := RewriteDescriptor
//                     | RewriteDescriptor RewriteDescriptors
// RewriteDescriptor := RewriteDescriptorType ':' '{' RewriteDescriptorFields '}'
// RewriteDescriptorFields := RewriteDescriptorField
//                          | RewriteDescriptorField RewriteDescriptorFields
19 // RewriteDescriptorField := FieldIdentifier ':' FieldValue ','
20 // RewriteDescriptorType := Identifier
21 // FieldIdentifier := Identifier
22 // FieldValue := Identifier
23 // Identifier := [0-9a-zA-Z]+
24 //
25 // Currently, the following descriptor types are supported:
26 //
27 // - function:          (function rewriting)
28 //      + Source        (original name of the function)
29 //      + Target        (explicit transformation)
30 //      + Transform     (pattern transformation)
31 //      + Naked         (boolean, whether the function is undecorated)
32 // - global variable:   (external linkage global variable rewriting)
33 //      + Source        (original name of externally visible variable)
34 //      + Target        (explicit transformation)
35 //      + Transform     (pattern transformation)
36 // - global alias:      (global alias rewriting)
37 //      + Source        (original name of the aliased name)
38 //      + Target        (explicit transformation)
39 //      + Transform     (pattern transformation)
40 //
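// Note that Source and exactly one of [Target, Transform] must be provided.
// In the map file the field names are spelled in lowercase ("source",
// "target", "transform", "naked").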
42 //
// New rewrite descriptors can be created.  Adding a new rewrite descriptor
// involves:
//
//  a) extending the rewrite descriptor kind enumeration
//     (SymbolRewriter::RewriteDescriptor::Type)
//  b) implementing the new descriptor
//     (cf. <anonymous>::ExplicitRewriteFunctionDescriptor)
//  c) extending the rewrite map parser
//     (SymbolRewriter::RewriteMapParser::parseEntry)
52 //
53 //  Specify to rewrite the symbols using the `-rewrite-symbols` option, and
54 //  specify the map file to use for the rewriting via the `-rewrite-map-file`
55 //  option.
56 //
57 //===----------------------------------------------------------------------===//
58 
59 #include "llvm/Transforms/Utils/SymbolRewriter.h"
60 #include "llvm/ADT/STLExtras.h"
61 #include "llvm/ADT/SmallString.h"
62 #include "llvm/ADT/StringRef.h"
63 #include "llvm/ADT/ilist.h"
64 #include "llvm/ADT/iterator_range.h"
65 #include "llvm/IR/Comdat.h"
66 #include "llvm/IR/Function.h"
67 #include "llvm/IR/GlobalAlias.h"
68 #include "llvm/IR/GlobalObject.h"
69 #include "llvm/IR/GlobalVariable.h"
70 #include "llvm/IR/Module.h"
71 #include "llvm/IR/Value.h"
72 #include "llvm/InitializePasses.h"
73 #include "llvm/Pass.h"
74 #include "llvm/Support/Casting.h"
75 #include "llvm/Support/CommandLine.h"
76 #include "llvm/Support/ErrorHandling.h"
77 #include "llvm/Support/ErrorOr.h"
78 #include "llvm/Support/MemoryBuffer.h"
79 #include "llvm/Support/Regex.h"
80 #include "llvm/Support/SourceMgr.h"
81 #include "llvm/Support/YAMLParser.h"
82 #include <memory>
83 #include <string>
84 #include <vector>
85 
86 using namespace llvm;
87 using namespace SymbolRewriter;
88 
89 #define DEBUG_TYPE "symbol-rewriter"
90 
// Paths of the YAML rewrite map files to use.  Each occurrence of
// -rewrite-map-file adds one path; every listed file is parsed by
// RewriteSymbolPass::loadAndParseMapFiles().
static cl::list<std::string> RewriteMapFiles("rewrite-map-file",
                                             cl::desc("Symbol Rewrite Map"),
                                             cl::value_desc("filename"),
                                             cl::Hidden);
95 
rewriteComdat(Module & M,GlobalObject * GO,const std::string & Source,const std::string & Target)96 static void rewriteComdat(Module &M, GlobalObject *GO,
97                           const std::string &Source,
98                           const std::string &Target) {
99   if (Comdat *CD = GO->getComdat()) {
100     auto &Comdats = M.getComdatSymbolTable();
101 
102     Comdat *C = M.getOrInsertComdat(Target);
103     C->setSelectionKind(CD->getSelectionKind());
104     GO->setComdat(C);
105 
106     Comdats.erase(Comdats.find(Source));
107   }
108 }
109 
110 namespace {
111 
112 template <RewriteDescriptor::Type DT, typename ValueType,
113           ValueType *(Module::*Get)(StringRef) const>
114 class ExplicitRewriteDescriptor : public RewriteDescriptor {
115 public:
116   const std::string Source;
117   const std::string Target;
118 
ExplicitRewriteDescriptor(StringRef S,StringRef T,const bool Naked)119   ExplicitRewriteDescriptor(StringRef S, StringRef T, const bool Naked)
120       : RewriteDescriptor(DT), Source(Naked ? StringRef("\01" + S.str()) : S),
121         Target(T) {}
122 
123   bool performOnModule(Module &M) override;
124 
classof(const RewriteDescriptor * RD)125   static bool classof(const RewriteDescriptor *RD) {
126     return RD->getType() == DT;
127   }
128 };
129 
130 } // end anonymous namespace
131 
132 template <RewriteDescriptor::Type DT, typename ValueType,
133           ValueType *(Module::*Get)(StringRef) const>
performOnModule(Module & M)134 bool ExplicitRewriteDescriptor<DT, ValueType, Get>::performOnModule(Module &M) {
135   bool Changed = false;
136   if (ValueType *S = (M.*Get)(Source)) {
137     if (GlobalObject *GO = dyn_cast<GlobalObject>(S))
138       rewriteComdat(M, GO, Source, Target);
139 
140     if (Value *T = (M.*Get)(Target))
141       S->setValueName(T->getValueName());
142     else
143       S->setName(Target);
144 
145     Changed = true;
146   }
147   return Changed;
148 }
149 
150 namespace {
151 
152 template <RewriteDescriptor::Type DT, typename ValueType,
153           ValueType *(Module::*Get)(StringRef) const,
154           iterator_range<typename iplist<ValueType>::iterator>
155           (Module::*Iterator)()>
156 class PatternRewriteDescriptor : public RewriteDescriptor {
157 public:
158   const std::string Pattern;
159   const std::string Transform;
160 
PatternRewriteDescriptor(StringRef P,StringRef T)161   PatternRewriteDescriptor(StringRef P, StringRef T)
162     : RewriteDescriptor(DT), Pattern(P), Transform(T) { }
163 
164   bool performOnModule(Module &M) override;
165 
classof(const RewriteDescriptor * RD)166   static bool classof(const RewriteDescriptor *RD) {
167     return RD->getType() == DT;
168   }
169 };
170 
171 } // end anonymous namespace
172 
173 template <RewriteDescriptor::Type DT, typename ValueType,
174           ValueType *(Module::*Get)(StringRef) const,
175           iterator_range<typename iplist<ValueType>::iterator>
176           (Module::*Iterator)()>
177 bool PatternRewriteDescriptor<DT, ValueType, Get, Iterator>::
performOnModule(Module & M)178 performOnModule(Module &M) {
179   bool Changed = false;
180   for (auto &C : (M.*Iterator)()) {
181     std::string Error;
182 
183     std::string Name = Regex(Pattern).sub(Transform, C.getName(), &Error);
184     if (!Error.empty())
185       report_fatal_error("unable to transforn " + C.getName() + " in " +
186                          M.getModuleIdentifier() + ": " + Error);
187 
188     if (C.getName() == Name)
189       continue;
190 
191     if (GlobalObject *GO = dyn_cast<GlobalObject>(&C))
192       rewriteComdat(M, GO, C.getName(), Name);
193 
194     if (Value *V = (M.*Get)(Name))
195       C.setValueName(V->getValueName());
196     else
197       C.setName(Name);
198 
199     Changed = true;
200   }
201   return Changed;
202 }
203 
namespace {

/// Explicit rewrite of a single function: both the source and the target
/// function names are spelled out.
using ExplicitRewriteFunctionDescriptor =
    ExplicitRewriteDescriptor<RewriteDescriptor::Type::Function, Function,
                              &Module::getFunction>;

/// Explicit rewrite of a single module-level global variable: both the source
/// and the target variable names are spelled out.
using ExplicitRewriteGlobalVariableDescriptor =
    ExplicitRewriteDescriptor<RewriteDescriptor::Type::GlobalVariable,
                              GlobalVariable, &Module::getGlobalVariable>;

/// Explicit rewrite of a single global alias: both the source and the target
/// alias names are spelled out.
using ExplicitRewriteNamedAliasDescriptor =
    ExplicitRewriteDescriptor<RewriteDescriptor::Type::NamedAlias, GlobalAlias,
                              &Module::getNamedAlias>;

/// Pattern-based rewrite of functions: each function name is matched against
/// a regular expression and the transformation determines the new name.
using PatternRewriteFunctionDescriptor =
    PatternRewriteDescriptor<RewriteDescriptor::Type::Function, Function,
                             &Module::getFunction, &Module::functions>;

/// Pattern-based rewrite of module-level global variables: each variable
/// matching the pattern is renamed as described by the transformation.
using PatternRewriteGlobalVariableDescriptor =
    PatternRewriteDescriptor<RewriteDescriptor::Type::GlobalVariable,
                             GlobalVariable, &Module::getGlobalVariable,
                             &Module::globals>;

/// Pattern-based rewrite of global aliases: the transformation is applied to
/// each alias name matching the pattern.
using PatternRewriteNamedAliasDescriptor =
    PatternRewriteDescriptor<RewriteDescriptor::Type::NamedAlias, GlobalAlias,
                             &Module::getNamedAlias, &Module::aliases>;

} // end anonymous namespace
250 
parse(const std::string & MapFile,RewriteDescriptorList * DL)251 bool RewriteMapParser::parse(const std::string &MapFile,
252                              RewriteDescriptorList *DL) {
253   ErrorOr<std::unique_ptr<MemoryBuffer>> Mapping =
254       MemoryBuffer::getFile(MapFile);
255 
256   if (!Mapping)
257     report_fatal_error("unable to read rewrite map '" + MapFile + "': " +
258                        Mapping.getError().message());
259 
260   if (!parse(*Mapping, DL))
261     report_fatal_error("unable to parse rewrite map '" + MapFile + "'");
262 
263   return true;
264 }
265 
parse(std::unique_ptr<MemoryBuffer> & MapFile,RewriteDescriptorList * DL)266 bool RewriteMapParser::parse(std::unique_ptr<MemoryBuffer> &MapFile,
267                              RewriteDescriptorList *DL) {
268   SourceMgr SM;
269   yaml::Stream YS(MapFile->getBuffer(), SM);
270 
271   for (auto &Document : YS) {
272     yaml::MappingNode *DescriptorList;
273 
274     // ignore empty documents
275     if (isa<yaml::NullNode>(Document.getRoot()))
276       continue;
277 
278     DescriptorList = dyn_cast<yaml::MappingNode>(Document.getRoot());
279     if (!DescriptorList) {
280       YS.printError(Document.getRoot(), "DescriptorList node must be a map");
281       return false;
282     }
283 
284     for (auto &Descriptor : *DescriptorList)
285       if (!parseEntry(YS, Descriptor, DL))
286         return false;
287   }
288 
289   return true;
290 }
291 
parseEntry(yaml::Stream & YS,yaml::KeyValueNode & Entry,RewriteDescriptorList * DL)292 bool RewriteMapParser::parseEntry(yaml::Stream &YS, yaml::KeyValueNode &Entry,
293                                   RewriteDescriptorList *DL) {
294   yaml::ScalarNode *Key;
295   yaml::MappingNode *Value;
296   SmallString<32> KeyStorage;
297   StringRef RewriteType;
298 
299   Key = dyn_cast<yaml::ScalarNode>(Entry.getKey());
300   if (!Key) {
301     YS.printError(Entry.getKey(), "rewrite type must be a scalar");
302     return false;
303   }
304 
305   Value = dyn_cast<yaml::MappingNode>(Entry.getValue());
306   if (!Value) {
307     YS.printError(Entry.getValue(), "rewrite descriptor must be a map");
308     return false;
309   }
310 
311   RewriteType = Key->getValue(KeyStorage);
312   if (RewriteType.equals("function"))
313     return parseRewriteFunctionDescriptor(YS, Key, Value, DL);
314   else if (RewriteType.equals("global variable"))
315     return parseRewriteGlobalVariableDescriptor(YS, Key, Value, DL);
316   else if (RewriteType.equals("global alias"))
317     return parseRewriteGlobalAliasDescriptor(YS, Key, Value, DL);
318 
319   YS.printError(Entry.getKey(), "unknown rewrite type");
320   return false;
321 }
322 
323 bool RewriteMapParser::
parseRewriteFunctionDescriptor(yaml::Stream & YS,yaml::ScalarNode * K,yaml::MappingNode * Descriptor,RewriteDescriptorList * DL)324 parseRewriteFunctionDescriptor(yaml::Stream &YS, yaml::ScalarNode *K,
325                                yaml::MappingNode *Descriptor,
326                                RewriteDescriptorList *DL) {
327   bool Naked = false;
328   std::string Source;
329   std::string Target;
330   std::string Transform;
331 
332   for (auto &Field : *Descriptor) {
333     yaml::ScalarNode *Key;
334     yaml::ScalarNode *Value;
335     SmallString<32> KeyStorage;
336     SmallString<32> ValueStorage;
337     StringRef KeyValue;
338 
339     Key = dyn_cast<yaml::ScalarNode>(Field.getKey());
340     if (!Key) {
341       YS.printError(Field.getKey(), "descriptor key must be a scalar");
342       return false;
343     }
344 
345     Value = dyn_cast<yaml::ScalarNode>(Field.getValue());
346     if (!Value) {
347       YS.printError(Field.getValue(), "descriptor value must be a scalar");
348       return false;
349     }
350 
351     KeyValue = Key->getValue(KeyStorage);
352     if (KeyValue.equals("source")) {
353       std::string Error;
354 
355       Source = Value->getValue(ValueStorage);
356       if (!Regex(Source).isValid(Error)) {
357         YS.printError(Field.getKey(), "invalid regex: " + Error);
358         return false;
359       }
360     } else if (KeyValue.equals("target")) {
361       Target = Value->getValue(ValueStorage);
362     } else if (KeyValue.equals("transform")) {
363       Transform = Value->getValue(ValueStorage);
364     } else if (KeyValue.equals("naked")) {
365       std::string Undecorated;
366 
367       Undecorated = Value->getValue(ValueStorage);
368       Naked = StringRef(Undecorated).lower() == "true" || Undecorated == "1";
369     } else {
370       YS.printError(Field.getKey(), "unknown key for function");
371       return false;
372     }
373   }
374 
375   if (Transform.empty() == Target.empty()) {
376     YS.printError(Descriptor,
377                   "exactly one of transform or target must be specified");
378     return false;
379   }
380 
381   // TODO see if there is a more elegant solution to selecting the rewrite
382   // descriptor type
383   if (!Target.empty())
384     DL->push_back(std::make_unique<ExplicitRewriteFunctionDescriptor>(
385         Source, Target, Naked));
386   else
387     DL->push_back(
388         std::make_unique<PatternRewriteFunctionDescriptor>(Source, Transform));
389 
390   return true;
391 }
392 
393 bool RewriteMapParser::
parseRewriteGlobalVariableDescriptor(yaml::Stream & YS,yaml::ScalarNode * K,yaml::MappingNode * Descriptor,RewriteDescriptorList * DL)394 parseRewriteGlobalVariableDescriptor(yaml::Stream &YS, yaml::ScalarNode *K,
395                                      yaml::MappingNode *Descriptor,
396                                      RewriteDescriptorList *DL) {
397   std::string Source;
398   std::string Target;
399   std::string Transform;
400 
401   for (auto &Field : *Descriptor) {
402     yaml::ScalarNode *Key;
403     yaml::ScalarNode *Value;
404     SmallString<32> KeyStorage;
405     SmallString<32> ValueStorage;
406     StringRef KeyValue;
407 
408     Key = dyn_cast<yaml::ScalarNode>(Field.getKey());
409     if (!Key) {
410       YS.printError(Field.getKey(), "descriptor Key must be a scalar");
411       return false;
412     }
413 
414     Value = dyn_cast<yaml::ScalarNode>(Field.getValue());
415     if (!Value) {
416       YS.printError(Field.getValue(), "descriptor value must be a scalar");
417       return false;
418     }
419 
420     KeyValue = Key->getValue(KeyStorage);
421     if (KeyValue.equals("source")) {
422       std::string Error;
423 
424       Source = Value->getValue(ValueStorage);
425       if (!Regex(Source).isValid(Error)) {
426         YS.printError(Field.getKey(), "invalid regex: " + Error);
427         return false;
428       }
429     } else if (KeyValue.equals("target")) {
430       Target = Value->getValue(ValueStorage);
431     } else if (KeyValue.equals("transform")) {
432       Transform = Value->getValue(ValueStorage);
433     } else {
434       YS.printError(Field.getKey(), "unknown Key for Global Variable");
435       return false;
436     }
437   }
438 
439   if (Transform.empty() == Target.empty()) {
440     YS.printError(Descriptor,
441                   "exactly one of transform or target must be specified");
442     return false;
443   }
444 
445   if (!Target.empty())
446     DL->push_back(std::make_unique<ExplicitRewriteGlobalVariableDescriptor>(
447         Source, Target,
448         /*Naked*/ false));
449   else
450     DL->push_back(std::make_unique<PatternRewriteGlobalVariableDescriptor>(
451         Source, Transform));
452 
453   return true;
454 }
455 
456 bool RewriteMapParser::
parseRewriteGlobalAliasDescriptor(yaml::Stream & YS,yaml::ScalarNode * K,yaml::MappingNode * Descriptor,RewriteDescriptorList * DL)457 parseRewriteGlobalAliasDescriptor(yaml::Stream &YS, yaml::ScalarNode *K,
458                                   yaml::MappingNode *Descriptor,
459                                   RewriteDescriptorList *DL) {
460   std::string Source;
461   std::string Target;
462   std::string Transform;
463 
464   for (auto &Field : *Descriptor) {
465     yaml::ScalarNode *Key;
466     yaml::ScalarNode *Value;
467     SmallString<32> KeyStorage;
468     SmallString<32> ValueStorage;
469     StringRef KeyValue;
470 
471     Key = dyn_cast<yaml::ScalarNode>(Field.getKey());
472     if (!Key) {
473       YS.printError(Field.getKey(), "descriptor key must be a scalar");
474       return false;
475     }
476 
477     Value = dyn_cast<yaml::ScalarNode>(Field.getValue());
478     if (!Value) {
479       YS.printError(Field.getValue(), "descriptor value must be a scalar");
480       return false;
481     }
482 
483     KeyValue = Key->getValue(KeyStorage);
484     if (KeyValue.equals("source")) {
485       std::string Error;
486 
487       Source = Value->getValue(ValueStorage);
488       if (!Regex(Source).isValid(Error)) {
489         YS.printError(Field.getKey(), "invalid regex: " + Error);
490         return false;
491       }
492     } else if (KeyValue.equals("target")) {
493       Target = Value->getValue(ValueStorage);
494     } else if (KeyValue.equals("transform")) {
495       Transform = Value->getValue(ValueStorage);
496     } else {
497       YS.printError(Field.getKey(), "unknown key for Global Alias");
498       return false;
499     }
500   }
501 
502   if (Transform.empty() == Target.empty()) {
503     YS.printError(Descriptor,
504                   "exactly one of transform or target must be specified");
505     return false;
506   }
507 
508   if (!Target.empty())
509     DL->push_back(std::make_unique<ExplicitRewriteNamedAliasDescriptor>(
510         Source, Target,
511         /*Naked*/ false));
512   else
513     DL->push_back(std::make_unique<PatternRewriteNamedAliasDescriptor>(
514         Source, Transform));
515 
516   return true;
517 }
518 
519 namespace {
520 
521 class RewriteSymbolsLegacyPass : public ModulePass {
522 public:
523   static char ID; // Pass identification, replacement for typeid
524 
525   RewriteSymbolsLegacyPass();
526   RewriteSymbolsLegacyPass(SymbolRewriter::RewriteDescriptorList &DL);
527 
528   bool runOnModule(Module &M) override;
529 
530 private:
531   RewriteSymbolPass Impl;
532 };
533 
534 } // end anonymous namespace
535 
536 char RewriteSymbolsLegacyPass::ID = 0;
537 
RewriteSymbolsLegacyPass()538 RewriteSymbolsLegacyPass::RewriteSymbolsLegacyPass() : ModulePass(ID) {
539   initializeRewriteSymbolsLegacyPassPass(*PassRegistry::getPassRegistry());
540 }
541 
RewriteSymbolsLegacyPass(SymbolRewriter::RewriteDescriptorList & DL)542 RewriteSymbolsLegacyPass::RewriteSymbolsLegacyPass(
543     SymbolRewriter::RewriteDescriptorList &DL)
544     : ModulePass(ID), Impl(DL) {}
545 
runOnModule(Module & M)546 bool RewriteSymbolsLegacyPass::runOnModule(Module &M) {
547   return Impl.runImpl(M);
548 }
549 
run(Module & M,ModuleAnalysisManager & AM)550 PreservedAnalyses RewriteSymbolPass::run(Module &M, ModuleAnalysisManager &AM) {
551   if (!runImpl(M))
552     return PreservedAnalyses::all();
553 
554   return PreservedAnalyses::none();
555 }
556 
runImpl(Module & M)557 bool RewriteSymbolPass::runImpl(Module &M) {
558   bool Changed;
559 
560   Changed = false;
561   for (auto &Descriptor : Descriptors)
562     Changed |= Descriptor->performOnModule(M);
563 
564   return Changed;
565 }
566 
loadAndParseMapFiles()567 void RewriteSymbolPass::loadAndParseMapFiles() {
568   const std::vector<std::string> MapFiles(RewriteMapFiles);
569   SymbolRewriter::RewriteMapParser Parser;
570 
571   for (const auto &MapFile : MapFiles)
572     Parser.parse(MapFile, &Descriptors);
573 }
574 
// Register the legacy pass under the command-line name "rewrite-symbols"
// (the -rewrite-symbols option described in the file header).
INITIALIZE_PASS(RewriteSymbolsLegacyPass, "rewrite-symbols", "Rewrite Symbols",
                false, false)
577 
createRewriteSymbolsPass()578 ModulePass *llvm::createRewriteSymbolsPass() {
579   return new RewriteSymbolsLegacyPass();
580 }
581 
582 ModulePass *
createRewriteSymbolsPass(SymbolRewriter::RewriteDescriptorList & DL)583 llvm::createRewriteSymbolsPass(SymbolRewriter::RewriteDescriptorList &DL) {
584   return new RewriteSymbolsLegacyPass(DL);
585 }
586