1 //===- SymbolRewriter.cpp - Symbol Rewriter -------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // SymbolRewriter is a LLVM pass which can rewrite symbols transparently within
10 // existing code.  It is implemented as a compiler pass and is configured via a
11 // YAML configuration file.
12 //
13 // The YAML configuration file format is as follows:
14 //
15 // RewriteMapFile := RewriteDescriptors
// RewriteDescriptors := RewriteDescriptor
//                     | RewriteDescriptor RewriteDescriptors
// RewriteDescriptor := RewriteDescriptorType ':' '{' RewriteDescriptorFields '}'
// RewriteDescriptorFields := RewriteDescriptorField
//                          | RewriteDescriptorField RewriteDescriptorFields
19 // RewriteDescriptorField := FieldIdentifier ':' FieldValue ','
20 // RewriteDescriptorType := Identifier
21 // FieldIdentifier := Identifier
22 // FieldValue := Identifier
23 // Identifier := [0-9a-zA-Z]+
24 //
25 // Currently, the following descriptor types are supported:
26 //
27 // - function:          (function rewriting)
28 //      + Source        (original name of the function)
29 //      + Target        (explicit transformation)
30 //      + Transform     (pattern transformation)
31 //      + Naked         (boolean, whether the function is undecorated)
32 // - global variable:   (external linkage global variable rewriting)
33 //      + Source        (original name of externally visible variable)
34 //      + Target        (explicit transformation)
35 //      + Transform     (pattern transformation)
36 // - global alias:      (global alias rewriting)
37 //      + Source        (original name of the aliased name)
38 //      + Target        (explicit transformation)
39 //      + Transform     (pattern transformation)
40 //
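// Note that Source and exactly one of [Target, Transform] must be provided.
// In the YAML file the field names are spelled in lowercase ("source",
// "target", "transform", "naked").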
42 //
// New rewrite descriptors can be created.  Adding a new rewrite descriptor
// involves:
//
//  a) extending the rewrite descriptor kind enumeration
//     (SymbolRewriter::RewriteDescriptor::Type)
//  b) implementing the new descriptor
//     (cf. <anonymous>::ExplicitRewriteFunctionDescriptor)
//  c) extending the rewrite map parser
//     (SymbolRewriter::RewriteMapParser::parseEntry)
52 //
53 //  Specify to rewrite the symbols using the `-rewrite-symbols` option, and
54 //  specify the map file to use for the rewriting via the `-rewrite-map-file`
55 //  option.
56 //
57 //===----------------------------------------------------------------------===//
58 
59 #include "llvm/Transforms/Utils/SymbolRewriter.h"
60 #include "llvm/ADT/STLExtras.h"
61 #include "llvm/ADT/SmallString.h"
62 #include "llvm/ADT/StringRef.h"
63 #include "llvm/ADT/ilist.h"
64 #include "llvm/ADT/iterator_range.h"
65 #include "llvm/IR/Comdat.h"
66 #include "llvm/IR/Function.h"
67 #include "llvm/IR/GlobalAlias.h"
68 #include "llvm/IR/GlobalObject.h"
69 #include "llvm/IR/GlobalVariable.h"
70 #include "llvm/IR/Module.h"
71 #include "llvm/IR/Value.h"
72 #include "llvm/InitializePasses.h"
73 #include "llvm/Pass.h"
74 #include "llvm/Support/Casting.h"
75 #include "llvm/Support/CommandLine.h"
76 #include "llvm/Support/ErrorHandling.h"
77 #include "llvm/Support/ErrorOr.h"
78 #include "llvm/Support/MemoryBuffer.h"
79 #include "llvm/Support/Regex.h"
80 #include "llvm/Support/SourceMgr.h"
81 #include "llvm/Support/YAMLParser.h"
82 #include <memory>
83 #include <string>
84 #include <vector>
85 
86 using namespace llvm;
87 using namespace SymbolRewriter;
88 
89 #define DEBUG_TYPE "symbol-rewriter"
90 
91 static cl::list<std::string> RewriteMapFiles("rewrite-map-file",
92                                              cl::desc("Symbol Rewrite Map"),
93                                              cl::value_desc("filename"),
94                                              cl::Hidden);
95 
96 static void rewriteComdat(Module &M, GlobalObject *GO,
97                           const std::string &Source,
98                           const std::string &Target) {
99   if (Comdat *CD = GO->getComdat()) {
100     auto &Comdats = M.getComdatSymbolTable();
101 
102     Comdat *C = M.getOrInsertComdat(Target);
103     C->setSelectionKind(CD->getSelectionKind());
104     GO->setComdat(C);
105 
106     Comdats.erase(Comdats.find(Source));
107   }
108 }
109 
110 namespace {
111 
112 template <RewriteDescriptor::Type DT, typename ValueType,
113           ValueType *(Module::*Get)(StringRef) const>
114 class ExplicitRewriteDescriptor : public RewriteDescriptor {
115 public:
116   const std::string Source;
117   const std::string Target;
118 
119   ExplicitRewriteDescriptor(StringRef S, StringRef T, const bool Naked)
120       : RewriteDescriptor(DT),
121         Source(std::string(Naked ? StringRef("\01" + S.str()) : S)),
122         Target(std::string(T)) {}
123 
124   bool performOnModule(Module &M) override;
125 
126   static bool classof(const RewriteDescriptor *RD) {
127     return RD->getType() == DT;
128   }
129 };
130 
131 } // end anonymous namespace
132 
133 template <RewriteDescriptor::Type DT, typename ValueType,
134           ValueType *(Module::*Get)(StringRef) const>
135 bool ExplicitRewriteDescriptor<DT, ValueType, Get>::performOnModule(Module &M) {
136   bool Changed = false;
137   if (ValueType *S = (M.*Get)(Source)) {
138     if (GlobalObject *GO = dyn_cast<GlobalObject>(S))
139       rewriteComdat(M, GO, Source, Target);
140 
141     if (Value *T = (M.*Get)(Target))
142       S->setValueName(T->getValueName());
143     else
144       S->setName(Target);
145 
146     Changed = true;
147   }
148   return Changed;
149 }
150 
151 namespace {
152 
153 template <RewriteDescriptor::Type DT, typename ValueType,
154           ValueType *(Module::*Get)(StringRef) const,
155           iterator_range<typename iplist<ValueType>::iterator>
156           (Module::*Iterator)()>
157 class PatternRewriteDescriptor : public RewriteDescriptor {
158 public:
159   const std::string Pattern;
160   const std::string Transform;
161 
162   PatternRewriteDescriptor(StringRef P, StringRef T)
163       : RewriteDescriptor(DT), Pattern(std::string(P)),
164         Transform(std::string(T)) {}
165 
166   bool performOnModule(Module &M) override;
167 
168   static bool classof(const RewriteDescriptor *RD) {
169     return RD->getType() == DT;
170   }
171 };
172 
173 } // end anonymous namespace
174 
175 template <RewriteDescriptor::Type DT, typename ValueType,
176           ValueType *(Module::*Get)(StringRef) const,
177           iterator_range<typename iplist<ValueType>::iterator>
178           (Module::*Iterator)()>
179 bool PatternRewriteDescriptor<DT, ValueType, Get, Iterator>::
180 performOnModule(Module &M) {
181   bool Changed = false;
182   for (auto &C : (M.*Iterator)()) {
183     std::string Error;
184 
185     std::string Name = Regex(Pattern).sub(Transform, C.getName(), &Error);
186     if (!Error.empty())
187       report_fatal_error("unable to transforn " + C.getName() + " in " +
188                          M.getModuleIdentifier() + ": " + Error);
189 
190     if (C.getName() == Name)
191       continue;
192 
193     if (GlobalObject *GO = dyn_cast<GlobalObject>(&C))
194       rewriteComdat(M, GO, std::string(C.getName()), Name);
195 
196     if (Value *V = (M.*Get)(Name))
197       C.setValueName(V->getValueName());
198     else
199       C.setName(Name);
200 
201     Changed = true;
202   }
203   return Changed;
204 }
205 
namespace {

// Concrete descriptor types.  Each alias binds a descriptor template to one
// kind of module-level entity and the Module accessors used to look it up
// (and, for pattern descriptors, to enumerate it).

/// Represents a rewrite for an explicitly named (function) symbol.  Both the
/// source function name and target function name of the transformation are
/// explicitly spelt out.
using ExplicitRewriteFunctionDescriptor =
    ExplicitRewriteDescriptor<RewriteDescriptor::Type::Function, Function,
                              &Module::getFunction>;

/// Represents a rewrite for an explicitly named (global variable) symbol.  Both
/// the source variable name and target variable name are spelt out.  This
/// applies only to module level variables.
using ExplicitRewriteGlobalVariableDescriptor =
    ExplicitRewriteDescriptor<RewriteDescriptor::Type::GlobalVariable,
                              GlobalVariable, &Module::getGlobalVariable>;

/// Represents a rewrite for an explicitly named global alias.  Both the source
/// and target name are explicitly spelt out.
using ExplicitRewriteNamedAliasDescriptor =
    ExplicitRewriteDescriptor<RewriteDescriptor::Type::NamedAlias, GlobalAlias,
                              &Module::getNamedAlias>;

/// Represents a rewrite for a regular expression based pattern for functions.
/// A pattern for the function name is provided and a transformation for that
/// pattern to determine the target function name create the rewrite rule.
using PatternRewriteFunctionDescriptor =
    PatternRewriteDescriptor<RewriteDescriptor::Type::Function, Function,
                             &Module::getFunction, &Module::functions>;

/// Represents a rewrite for a global variable based upon a matching pattern.
/// Each global variable matching the provided pattern will be transformed as
/// described in the transformation pattern for the target.  Applies only to
/// module level variables.
using PatternRewriteGlobalVariableDescriptor =
    PatternRewriteDescriptor<RewriteDescriptor::Type::GlobalVariable,
                             GlobalVariable, &Module::getGlobalVariable,
                             &Module::globals>;

/// Represents a rewrite for global aliases which match a given pattern.  The
/// provided transformation will be applied to each of the matching names.
using PatternRewriteNamedAliasDescriptor =
    PatternRewriteDescriptor<RewriteDescriptor::Type::NamedAlias, GlobalAlias,
                             &Module::getNamedAlias, &Module::aliases>;

} // end anonymous namespace
252 
253 bool RewriteMapParser::parse(const std::string &MapFile,
254                              RewriteDescriptorList *DL) {
255   ErrorOr<std::unique_ptr<MemoryBuffer>> Mapping =
256       MemoryBuffer::getFile(MapFile);
257 
258   if (!Mapping)
259     report_fatal_error("unable to read rewrite map '" + MapFile + "': " +
260                        Mapping.getError().message());
261 
262   if (!parse(*Mapping, DL))
263     report_fatal_error("unable to parse rewrite map '" + MapFile + "'");
264 
265   return true;
266 }
267 
268 bool RewriteMapParser::parse(std::unique_ptr<MemoryBuffer> &MapFile,
269                              RewriteDescriptorList *DL) {
270   SourceMgr SM;
271   yaml::Stream YS(MapFile->getBuffer(), SM);
272 
273   for (auto &Document : YS) {
274     yaml::MappingNode *DescriptorList;
275 
276     // ignore empty documents
277     if (isa<yaml::NullNode>(Document.getRoot()))
278       continue;
279 
280     DescriptorList = dyn_cast<yaml::MappingNode>(Document.getRoot());
281     if (!DescriptorList) {
282       YS.printError(Document.getRoot(), "DescriptorList node must be a map");
283       return false;
284     }
285 
286     for (auto &Descriptor : *DescriptorList)
287       if (!parseEntry(YS, Descriptor, DL))
288         return false;
289   }
290 
291   return true;
292 }
293 
294 bool RewriteMapParser::parseEntry(yaml::Stream &YS, yaml::KeyValueNode &Entry,
295                                   RewriteDescriptorList *DL) {
296   yaml::ScalarNode *Key;
297   yaml::MappingNode *Value;
298   SmallString<32> KeyStorage;
299   StringRef RewriteType;
300 
301   Key = dyn_cast<yaml::ScalarNode>(Entry.getKey());
302   if (!Key) {
303     YS.printError(Entry.getKey(), "rewrite type must be a scalar");
304     return false;
305   }
306 
307   Value = dyn_cast<yaml::MappingNode>(Entry.getValue());
308   if (!Value) {
309     YS.printError(Entry.getValue(), "rewrite descriptor must be a map");
310     return false;
311   }
312 
313   RewriteType = Key->getValue(KeyStorage);
314   if (RewriteType.equals("function"))
315     return parseRewriteFunctionDescriptor(YS, Key, Value, DL);
316   else if (RewriteType.equals("global variable"))
317     return parseRewriteGlobalVariableDescriptor(YS, Key, Value, DL);
318   else if (RewriteType.equals("global alias"))
319     return parseRewriteGlobalAliasDescriptor(YS, Key, Value, DL);
320 
321   YS.printError(Entry.getKey(), "unknown rewrite type");
322   return false;
323 }
324 
325 bool RewriteMapParser::
326 parseRewriteFunctionDescriptor(yaml::Stream &YS, yaml::ScalarNode *K,
327                                yaml::MappingNode *Descriptor,
328                                RewriteDescriptorList *DL) {
329   bool Naked = false;
330   std::string Source;
331   std::string Target;
332   std::string Transform;
333 
334   for (auto &Field : *Descriptor) {
335     yaml::ScalarNode *Key;
336     yaml::ScalarNode *Value;
337     SmallString<32> KeyStorage;
338     SmallString<32> ValueStorage;
339     StringRef KeyValue;
340 
341     Key = dyn_cast<yaml::ScalarNode>(Field.getKey());
342     if (!Key) {
343       YS.printError(Field.getKey(), "descriptor key must be a scalar");
344       return false;
345     }
346 
347     Value = dyn_cast<yaml::ScalarNode>(Field.getValue());
348     if (!Value) {
349       YS.printError(Field.getValue(), "descriptor value must be a scalar");
350       return false;
351     }
352 
353     KeyValue = Key->getValue(KeyStorage);
354     if (KeyValue.equals("source")) {
355       std::string Error;
356 
357       Source = std::string(Value->getValue(ValueStorage));
358       if (!Regex(Source).isValid(Error)) {
359         YS.printError(Field.getKey(), "invalid regex: " + Error);
360         return false;
361       }
362     } else if (KeyValue.equals("target")) {
363       Target = std::string(Value->getValue(ValueStorage));
364     } else if (KeyValue.equals("transform")) {
365       Transform = std::string(Value->getValue(ValueStorage));
366     } else if (KeyValue.equals("naked")) {
367       std::string Undecorated;
368 
369       Undecorated = std::string(Value->getValue(ValueStorage));
370       Naked = StringRef(Undecorated).lower() == "true" || Undecorated == "1";
371     } else {
372       YS.printError(Field.getKey(), "unknown key for function");
373       return false;
374     }
375   }
376 
377   if (Transform.empty() == Target.empty()) {
378     YS.printError(Descriptor,
379                   "exactly one of transform or target must be specified");
380     return false;
381   }
382 
383   // TODO see if there is a more elegant solution to selecting the rewrite
384   // descriptor type
385   if (!Target.empty())
386     DL->push_back(std::make_unique<ExplicitRewriteFunctionDescriptor>(
387         Source, Target, Naked));
388   else
389     DL->push_back(
390         std::make_unique<PatternRewriteFunctionDescriptor>(Source, Transform));
391 
392   return true;
393 }
394 
395 bool RewriteMapParser::
396 parseRewriteGlobalVariableDescriptor(yaml::Stream &YS, yaml::ScalarNode *K,
397                                      yaml::MappingNode *Descriptor,
398                                      RewriteDescriptorList *DL) {
399   std::string Source;
400   std::string Target;
401   std::string Transform;
402 
403   for (auto &Field : *Descriptor) {
404     yaml::ScalarNode *Key;
405     yaml::ScalarNode *Value;
406     SmallString<32> KeyStorage;
407     SmallString<32> ValueStorage;
408     StringRef KeyValue;
409 
410     Key = dyn_cast<yaml::ScalarNode>(Field.getKey());
411     if (!Key) {
412       YS.printError(Field.getKey(), "descriptor Key must be a scalar");
413       return false;
414     }
415 
416     Value = dyn_cast<yaml::ScalarNode>(Field.getValue());
417     if (!Value) {
418       YS.printError(Field.getValue(), "descriptor value must be a scalar");
419       return false;
420     }
421 
422     KeyValue = Key->getValue(KeyStorage);
423     if (KeyValue.equals("source")) {
424       std::string Error;
425 
426       Source = std::string(Value->getValue(ValueStorage));
427       if (!Regex(Source).isValid(Error)) {
428         YS.printError(Field.getKey(), "invalid regex: " + Error);
429         return false;
430       }
431     } else if (KeyValue.equals("target")) {
432       Target = std::string(Value->getValue(ValueStorage));
433     } else if (KeyValue.equals("transform")) {
434       Transform = std::string(Value->getValue(ValueStorage));
435     } else {
436       YS.printError(Field.getKey(), "unknown Key for Global Variable");
437       return false;
438     }
439   }
440 
441   if (Transform.empty() == Target.empty()) {
442     YS.printError(Descriptor,
443                   "exactly one of transform or target must be specified");
444     return false;
445   }
446 
447   if (!Target.empty())
448     DL->push_back(std::make_unique<ExplicitRewriteGlobalVariableDescriptor>(
449         Source, Target,
450         /*Naked*/ false));
451   else
452     DL->push_back(std::make_unique<PatternRewriteGlobalVariableDescriptor>(
453         Source, Transform));
454 
455   return true;
456 }
457 
458 bool RewriteMapParser::
459 parseRewriteGlobalAliasDescriptor(yaml::Stream &YS, yaml::ScalarNode *K,
460                                   yaml::MappingNode *Descriptor,
461                                   RewriteDescriptorList *DL) {
462   std::string Source;
463   std::string Target;
464   std::string Transform;
465 
466   for (auto &Field : *Descriptor) {
467     yaml::ScalarNode *Key;
468     yaml::ScalarNode *Value;
469     SmallString<32> KeyStorage;
470     SmallString<32> ValueStorage;
471     StringRef KeyValue;
472 
473     Key = dyn_cast<yaml::ScalarNode>(Field.getKey());
474     if (!Key) {
475       YS.printError(Field.getKey(), "descriptor key must be a scalar");
476       return false;
477     }
478 
479     Value = dyn_cast<yaml::ScalarNode>(Field.getValue());
480     if (!Value) {
481       YS.printError(Field.getValue(), "descriptor value must be a scalar");
482       return false;
483     }
484 
485     KeyValue = Key->getValue(KeyStorage);
486     if (KeyValue.equals("source")) {
487       std::string Error;
488 
489       Source = std::string(Value->getValue(ValueStorage));
490       if (!Regex(Source).isValid(Error)) {
491         YS.printError(Field.getKey(), "invalid regex: " + Error);
492         return false;
493       }
494     } else if (KeyValue.equals("target")) {
495       Target = std::string(Value->getValue(ValueStorage));
496     } else if (KeyValue.equals("transform")) {
497       Transform = std::string(Value->getValue(ValueStorage));
498     } else {
499       YS.printError(Field.getKey(), "unknown key for Global Alias");
500       return false;
501     }
502   }
503 
504   if (Transform.empty() == Target.empty()) {
505     YS.printError(Descriptor,
506                   "exactly one of transform or target must be specified");
507     return false;
508   }
509 
510   if (!Target.empty())
511     DL->push_back(std::make_unique<ExplicitRewriteNamedAliasDescriptor>(
512         Source, Target,
513         /*Naked*/ false));
514   else
515     DL->push_back(std::make_unique<PatternRewriteNamedAliasDescriptor>(
516         Source, Transform));
517 
518   return true;
519 }
520 
521 namespace {
522 
523 class RewriteSymbolsLegacyPass : public ModulePass {
524 public:
525   static char ID; // Pass identification, replacement for typeid
526 
527   RewriteSymbolsLegacyPass();
528   RewriteSymbolsLegacyPass(SymbolRewriter::RewriteDescriptorList &DL);
529 
530   bool runOnModule(Module &M) override;
531 
532 private:
533   RewriteSymbolPass Impl;
534 };
535 
536 } // end anonymous namespace
537 
538 char RewriteSymbolsLegacyPass::ID = 0;
539 
540 RewriteSymbolsLegacyPass::RewriteSymbolsLegacyPass() : ModulePass(ID) {
541   initializeRewriteSymbolsLegacyPassPass(*PassRegistry::getPassRegistry());
542 }
543 
544 RewriteSymbolsLegacyPass::RewriteSymbolsLegacyPass(
545     SymbolRewriter::RewriteDescriptorList &DL)
546     : ModulePass(ID), Impl(DL) {}
547 
548 bool RewriteSymbolsLegacyPass::runOnModule(Module &M) {
549   return Impl.runImpl(M);
550 }
551 
552 PreservedAnalyses RewriteSymbolPass::run(Module &M, ModuleAnalysisManager &AM) {
553   if (!runImpl(M))
554     return PreservedAnalyses::all();
555 
556   return PreservedAnalyses::none();
557 }
558 
559 bool RewriteSymbolPass::runImpl(Module &M) {
560   bool Changed;
561 
562   Changed = false;
563   for (auto &Descriptor : Descriptors)
564     Changed |= Descriptor->performOnModule(M);
565 
566   return Changed;
567 }
568 
569 void RewriteSymbolPass::loadAndParseMapFiles() {
570   const std::vector<std::string> MapFiles(RewriteMapFiles);
571   SymbolRewriter::RewriteMapParser Parser;
572 
573   for (const auto &MapFile : MapFiles)
574     Parser.parse(MapFile, &Descriptors);
575 }
576 
// Register the legacy pass with the PassRegistry under the command-line name
// "rewrite-symbols" (description "Rewrite Symbols").  The trailing
// (false, false) are INITIALIZE_PASS's CFG-only and is-analysis flags.
INITIALIZE_PASS(RewriteSymbolsLegacyPass, "rewrite-symbols", "Rewrite Symbols",
                false, false)
579 
580 ModulePass *llvm::createRewriteSymbolsPass() {
581   return new RewriteSymbolsLegacyPass();
582 }
583 
584 ModulePass *
585 llvm::createRewriteSymbolsPass(SymbolRewriter::RewriteDescriptorList &DL) {
586   return new RewriteSymbolsLegacyPass(DL);
587 }
588