1 //===--- RefactoringCallbacks.cpp - Structural query framework ------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
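//
// Implements RefactoringCallback and its subclasses (ReplaceStmtWithText,
// ReplaceStmtWithStmt, ReplaceIfStmtWithItsBody, ReplaceNodeWithTemplate)
// and ASTMatchRefactorer, which runs their matchers over each translation
// unit and collects the resulting replacements per file.
//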
10 //===----------------------------------------------------------------------===//
11 #include "clang/Tooling/RefactoringCallbacks.h"
12 #include "clang/ASTMatchers/ASTMatchFinder.h"
13 #include "clang/Basic/SourceLocation.h"
14 #include "clang/Lex/Lexer.h"
15 
16 using llvm::StringError;
17 using llvm::make_error;
18 
19 namespace clang {
20 namespace tooling {
21 
22 RefactoringCallback::RefactoringCallback() {}
23 tooling::Replacements &RefactoringCallback::getReplacements() {
24   return Replace;
25 }
26 
27 ASTMatchRefactorer::ASTMatchRefactorer(
28     std::map<std::string, Replacements> &FileToReplaces)
29     : FileToReplaces(FileToReplaces) {}
30 
31 void ASTMatchRefactorer::addDynamicMatcher(
32     const ast_matchers::internal::DynTypedMatcher &Matcher,
33     RefactoringCallback *Callback) {
34   MatchFinder.addDynamicMatcher(Matcher, Callback);
35   Callbacks.push_back(Callback);
36 }
37 
38 class RefactoringASTConsumer : public ASTConsumer {
39 public:
40   explicit RefactoringASTConsumer(ASTMatchRefactorer &Refactoring)
41       : Refactoring(Refactoring) {}
42 
43   void HandleTranslationUnit(ASTContext &Context) override {
44     // The ASTMatchRefactorer is re-used between translation units.
45     // Clear the matchers so that each Replacement is only emitted once.
46     for (const auto &Callback : Refactoring.Callbacks) {
47       Callback->getReplacements().clear();
48     }
49     Refactoring.MatchFinder.matchAST(Context);
50     for (const auto &Callback : Refactoring.Callbacks) {
51       for (const auto &Replacement : Callback->getReplacements()) {
52         llvm::Error Err =
53             Refactoring.FileToReplaces[Replacement.getFilePath()].add(
54                 Replacement);
55         if (Err) {
56           llvm::errs() << "Skipping replacement " << Replacement.toString()
57                        << " due to this error:\n"
58                        << toString(std::move(Err)) << "\n";
59         }
60       }
61     }
62   }
63 
64 private:
65   ASTMatchRefactorer &Refactoring;
66 };
67 
68 std::unique_ptr<ASTConsumer> ASTMatchRefactorer::newASTConsumer() {
69   return std::make_unique<RefactoringASTConsumer>(*this);
70 }
71 
72 static Replacement replaceStmtWithText(SourceManager &Sources, const Stmt &From,
73                                        StringRef Text) {
74   return tooling::Replacement(
75       Sources, CharSourceRange::getTokenRange(From.getSourceRange()), Text);
76 }
77 static Replacement replaceStmtWithStmt(SourceManager &Sources, const Stmt &From,
78                                        const Stmt &To) {
79   return replaceStmtWithText(
80       Sources, From,
81       Lexer::getSourceText(CharSourceRange::getTokenRange(To.getSourceRange()),
82                            Sources, LangOptions()));
83 }
84 
85 ReplaceStmtWithText::ReplaceStmtWithText(StringRef FromId, StringRef ToText)
86     : FromId(FromId), ToText(ToText) {}
87 
88 void ReplaceStmtWithText::run(
89     const ast_matchers::MatchFinder::MatchResult &Result) {
90   if (const Stmt *FromMatch = Result.Nodes.getNodeAs<Stmt>(FromId)) {
91     auto Err = Replace.add(tooling::Replacement(
92         *Result.SourceManager,
93         CharSourceRange::getTokenRange(FromMatch->getSourceRange()), ToText));
94     // FIXME: better error handling. For now, just print error message in the
95     // release version.
96     if (Err) {
97       llvm::errs() << llvm::toString(std::move(Err)) << "\n";
98       assert(false);
99     }
100   }
101 }
102 
103 ReplaceStmtWithStmt::ReplaceStmtWithStmt(StringRef FromId, StringRef ToId)
104     : FromId(FromId), ToId(ToId) {}
105 
106 void ReplaceStmtWithStmt::run(
107     const ast_matchers::MatchFinder::MatchResult &Result) {
108   const Stmt *FromMatch = Result.Nodes.getNodeAs<Stmt>(FromId);
109   const Stmt *ToMatch = Result.Nodes.getNodeAs<Stmt>(ToId);
110   if (FromMatch && ToMatch) {
111     auto Err = Replace.add(
112         replaceStmtWithStmt(*Result.SourceManager, *FromMatch, *ToMatch));
113     // FIXME: better error handling. For now, just print error message in the
114     // release version.
115     if (Err) {
116       llvm::errs() << llvm::toString(std::move(Err)) << "\n";
117       assert(false);
118     }
119   }
120 }
121 
122 ReplaceIfStmtWithItsBody::ReplaceIfStmtWithItsBody(StringRef Id,
123                                                    bool PickTrueBranch)
124     : Id(Id), PickTrueBranch(PickTrueBranch) {}
125 
126 void ReplaceIfStmtWithItsBody::run(
127     const ast_matchers::MatchFinder::MatchResult &Result) {
128   if (const IfStmt *Node = Result.Nodes.getNodeAs<IfStmt>(Id)) {
129     const Stmt *Body = PickTrueBranch ? Node->getThen() : Node->getElse();
130     if (Body) {
131       auto Err =
132           Replace.add(replaceStmtWithStmt(*Result.SourceManager, *Node, *Body));
133       // FIXME: better error handling. For now, just print error message in the
134       // release version.
135       if (Err) {
136         llvm::errs() << llvm::toString(std::move(Err)) << "\n";
137         assert(false);
138       }
139     } else if (!PickTrueBranch) {
140       // If we want to use the 'else'-branch, but it doesn't exist, delete
141       // the whole 'if'.
142       auto Err =
143           Replace.add(replaceStmtWithText(*Result.SourceManager, *Node, ""));
144       // FIXME: better error handling. For now, just print error message in the
145       // release version.
146       if (Err) {
147         llvm::errs() << llvm::toString(std::move(Err)) << "\n";
148         assert(false);
149       }
150     }
151   }
152 }
153 
154 ReplaceNodeWithTemplate::ReplaceNodeWithTemplate(
155     llvm::StringRef FromId, std::vector<TemplateElement> Template)
156     : FromId(FromId), Template(std::move(Template)) {}
157 
158 llvm::Expected<std::unique_ptr<ReplaceNodeWithTemplate>>
159 ReplaceNodeWithTemplate::create(StringRef FromId, StringRef ToTemplate) {
160   std::vector<TemplateElement> ParsedTemplate;
161   for (size_t Index = 0; Index < ToTemplate.size();) {
162     if (ToTemplate[Index] == '$') {
163       if (ToTemplate.substr(Index, 2) == "$$") {
164         Index += 2;
165         ParsedTemplate.push_back(
166             TemplateElement{TemplateElement::Literal, "$"});
167       } else if (ToTemplate.substr(Index, 2) == "${") {
168         size_t EndOfIdentifier = ToTemplate.find("}", Index);
169         if (EndOfIdentifier == std::string::npos) {
170           return make_error<StringError>(
171               "Unterminated ${...} in replacement template near " +
172                   ToTemplate.substr(Index),
173               llvm::inconvertibleErrorCode());
174         }
175         std::string SourceNodeName =
176             ToTemplate.substr(Index + 2, EndOfIdentifier - Index - 2);
177         ParsedTemplate.push_back(
178             TemplateElement{TemplateElement::Identifier, SourceNodeName});
179         Index = EndOfIdentifier + 1;
180       } else {
181         return make_error<StringError>(
182             "Invalid $ in replacement template near " +
183                 ToTemplate.substr(Index),
184             llvm::inconvertibleErrorCode());
185       }
186     } else {
187       size_t NextIndex = ToTemplate.find('$', Index + 1);
188       ParsedTemplate.push_back(
189           TemplateElement{TemplateElement::Literal,
190                           ToTemplate.substr(Index, NextIndex - Index)});
191       Index = NextIndex;
192     }
193   }
194   return std::unique_ptr<ReplaceNodeWithTemplate>(
195       new ReplaceNodeWithTemplate(FromId, std::move(ParsedTemplate)));
196 }
197 
198 void ReplaceNodeWithTemplate::run(
199     const ast_matchers::MatchFinder::MatchResult &Result) {
200   const auto &NodeMap = Result.Nodes.getMap();
201 
202   std::string ToText;
203   for (const auto &Element : Template) {
204     switch (Element.Type) {
205     case TemplateElement::Literal:
206       ToText += Element.Value;
207       break;
208     case TemplateElement::Identifier: {
209       auto NodeIter = NodeMap.find(Element.Value);
210       if (NodeIter == NodeMap.end()) {
211         llvm::errs() << "Node " << Element.Value
212                      << " used in replacement template not bound in Matcher \n";
213         llvm::report_fatal_error("Unbound node in replacement template.");
214       }
215       CharSourceRange Source =
216           CharSourceRange::getTokenRange(NodeIter->second.getSourceRange());
217       ToText += Lexer::getSourceText(Source, *Result.SourceManager,
218                                      Result.Context->getLangOpts());
219       break;
220     }
221     }
222   }
223   if (NodeMap.count(FromId) == 0) {
224     llvm::errs() << "Node to be replaced " << FromId
225                  << " not bound in query.\n";
226     llvm::report_fatal_error("FromId node not bound in MatchResult");
227   }
228   auto Replacement =
229       tooling::Replacement(*Result.SourceManager, &NodeMap.at(FromId), ToText,
230                            Result.Context->getLangOpts());
231   llvm::Error Err = Replace.add(Replacement);
232   if (Err) {
233     llvm::errs() << "Query and replace failed in " << Replacement.getFilePath()
234                  << "! " << llvm::toString(std::move(Err)) << "\n";
235     llvm::report_fatal_error("Replacement failed");
236   }
237 }
238 
239 } // end namespace tooling
240 } // end namespace clang
241