1 //===--- TestSupport.cpp - Clang-based refactoring tool -------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 ///
9 /// \file
10 /// This file implements routines that provide refactoring testing
11 /// utilities.
12 ///
13 //===----------------------------------------------------------------------===//
14 
15 #include "TestSupport.h"
16 #include "clang/Basic/DiagnosticError.h"
17 #include "clang/Basic/FileManager.h"
18 #include "clang/Basic/SourceManager.h"
19 #include "clang/Lex/Lexer.h"
20 #include "llvm/ADT/STLExtras.h"
21 #include "llvm/Support/Error.h"
22 #include "llvm/Support/ErrorOr.h"
23 #include "llvm/Support/LineIterator.h"
24 #include "llvm/Support/MemoryBuffer.h"
25 #include "llvm/Support/Regex.h"
26 #include "llvm/Support/raw_ostream.h"
27 #include <optional>
28 
29 using namespace llvm;
30 
31 namespace clang {
32 namespace refactor {
33 
34 void TestSelectionRangesInFile::dump(raw_ostream &OS) const {
35   for (const auto &Group : GroupedRanges) {
36     OS << "Test selection group '" << Group.Name << "':\n";
37     for (const auto &Range : Group.Ranges) {
38       OS << "  " << Range.Begin << "-" << Range.End << "\n";
39     }
40   }
41 }
42 
43 bool TestSelectionRangesInFile::foreachRange(
44     const SourceManager &SM,
45     llvm::function_ref<void(SourceRange)> Callback) const {
46   auto FE = SM.getFileManager().getFile(Filename);
47   FileID FID = FE ? SM.translateFile(*FE) : FileID();
48   if (!FE || FID.isInvalid()) {
49     llvm::errs() << "error: -selection=test:" << Filename
50                  << " : given file is not in the target TU";
51     return true;
52   }
53   SourceLocation FileLoc = SM.getLocForStartOfFile(FID);
54   for (const auto &Group : GroupedRanges) {
55     for (const TestSelectionRange &Range : Group.Ranges) {
56       // Translate the offset pair to a true source range.
57       SourceLocation Start =
58           SM.getMacroArgExpandedLocation(FileLoc.getLocWithOffset(Range.Begin));
59       SourceLocation End =
60           SM.getMacroArgExpandedLocation(FileLoc.getLocWithOffset(Range.End));
61       assert(Start.isValid() && End.isValid() && "unexpected invalid range");
62       Callback(SourceRange(Start, End));
63     }
64   }
65   return false;
66 }
67 
68 namespace {
69 
70 void dumpChanges(const tooling::AtomicChanges &Changes, raw_ostream &OS) {
71   for (const auto &Change : Changes)
72     OS << const_cast<tooling::AtomicChange &>(Change).toYAMLString() << "\n";
73 }
74 
75 bool areChangesSame(const tooling::AtomicChanges &LHS,
76                     const tooling::AtomicChanges &RHS) {
77   if (LHS.size() != RHS.size())
78     return false;
79   for (auto I : llvm::zip(LHS, RHS)) {
80     if (!(std::get<0>(I) == std::get<1>(I)))
81       return false;
82   }
83   return true;
84 }
85 
86 bool printRewrittenSources(const tooling::AtomicChanges &Changes,
87                            raw_ostream &OS) {
88   std::set<std::string> Files;
89   for (const auto &Change : Changes)
90     Files.insert(Change.getFilePath());
91   tooling::ApplyChangesSpec Spec;
92   Spec.Cleanup = false;
93   for (const auto &File : Files) {
94     llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> BufferErr =
95         llvm::MemoryBuffer::getFile(File);
96     if (!BufferErr) {
97       llvm::errs() << "failed to open" << File << "\n";
98       return true;
99     }
100     auto Result = tooling::applyAtomicChanges(File, (*BufferErr)->getBuffer(),
101                                               Changes, Spec);
102     if (!Result) {
103       llvm::errs() << toString(Result.takeError());
104       return true;
105     }
106     OS << *Result;
107   }
108   return false;
109 }
110 
111 class TestRefactoringResultConsumer final
112     : public ClangRefactorToolConsumerInterface {
113 public:
114   TestRefactoringResultConsumer(const TestSelectionRangesInFile &TestRanges)
115       : TestRanges(TestRanges) {
116     Results.push_back({});
117   }
118 
119   ~TestRefactoringResultConsumer() {
120     // Ensure all results are checked.
121     for (auto &Group : Results) {
122       for (auto &Result : Group) {
123         if (!Result) {
124           (void)llvm::toString(Result.takeError());
125         }
126       }
127     }
128   }
129 
130   void handleError(llvm::Error Err) override { handleResult(std::move(Err)); }
131 
132   void handle(tooling::AtomicChanges Changes) override {
133     handleResult(std::move(Changes));
134   }
135 
136   void handle(tooling::SymbolOccurrences Occurrences) override {
137     tooling::RefactoringResultConsumer::handle(std::move(Occurrences));
138   }
139 
140 private:
141   bool handleAllResults();
142 
143   void handleResult(Expected<tooling::AtomicChanges> Result) {
144     Results.back().push_back(std::move(Result));
145     size_t GroupIndex = Results.size() - 1;
146     if (Results.back().size() >=
147         TestRanges.GroupedRanges[GroupIndex].Ranges.size()) {
148       ++GroupIndex;
149       if (GroupIndex >= TestRanges.GroupedRanges.size()) {
150         if (handleAllResults())
151           exit(1); // error has occurred.
152         return;
153       }
154       Results.push_back({});
155     }
156   }
157 
158   const TestSelectionRangesInFile &TestRanges;
159   std::vector<std::vector<Expected<tooling::AtomicChanges>>> Results;
160 };
161 
162 std::pair<unsigned, unsigned> getLineColumn(StringRef Filename,
163                                             unsigned Offset) {
164   ErrorOr<std::unique_ptr<MemoryBuffer>> ErrOrFile =
165       MemoryBuffer::getFile(Filename);
166   if (!ErrOrFile)
167     return {0, 0};
168   StringRef Source = ErrOrFile.get()->getBuffer();
169   Source = Source.take_front(Offset);
170   size_t LastLine = Source.find_last_of("\r\n");
171   return {Source.count('\n') + 1,
172           (LastLine == StringRef::npos ? Offset : Offset - LastLine) + 1};
173 }
174 
175 } // end anonymous namespace
176 
177 bool TestRefactoringResultConsumer::handleAllResults() {
178   bool Failed = false;
179   for (auto &Group : llvm::enumerate(Results)) {
180     // All ranges in the group must produce the same result.
181     std::optional<tooling::AtomicChanges> CanonicalResult;
182     std::optional<std::string> CanonicalErrorMessage;
183     for (auto &I : llvm::enumerate(Group.value())) {
184       Expected<tooling::AtomicChanges> &Result = I.value();
185       std::string ErrorMessage;
186       bool HasResult = !!Result;
187       if (!HasResult) {
188         handleAllErrors(
189             Result.takeError(),
190             [&](StringError &Err) { ErrorMessage = Err.getMessage(); },
191             [&](DiagnosticError &Err) {
192               const PartialDiagnosticAt &Diag = Err.getDiagnostic();
193               llvm::SmallString<100> DiagText;
194               Diag.second.EmitToString(getDiags(), DiagText);
195               ErrorMessage = std::string(DiagText);
196             });
197       }
198       if (!CanonicalResult && !CanonicalErrorMessage) {
199         if (HasResult)
200           CanonicalResult = std::move(*Result);
201         else
202           CanonicalErrorMessage = std::move(ErrorMessage);
203         continue;
204       }
205 
206       // Verify that this result corresponds to the canonical result.
207       if (CanonicalErrorMessage) {
208         // The error messages must match.
209         if (!HasResult && ErrorMessage == *CanonicalErrorMessage)
210           continue;
211       } else {
212         assert(CanonicalResult && "missing canonical result");
213         // The results must match.
214         if (HasResult && areChangesSame(*Result, *CanonicalResult))
215           continue;
216       }
217       Failed = true;
218       // Report the mismatch.
219       std::pair<unsigned, unsigned> LineColumn = getLineColumn(
220           TestRanges.Filename,
221           TestRanges.GroupedRanges[Group.index()].Ranges[I.index()].Begin);
222       llvm::errs()
223           << "error: unexpected refactoring result for range starting at "
224           << LineColumn.first << ':' << LineColumn.second << " in group '"
225           << TestRanges.GroupedRanges[Group.index()].Name << "':\n  ";
226       if (HasResult)
227         llvm::errs() << "valid result";
228       else
229         llvm::errs() << "error '" << ErrorMessage << "'";
230       llvm::errs() << " does not match initial ";
231       if (CanonicalErrorMessage)
232         llvm::errs() << "error '" << *CanonicalErrorMessage << "'\n";
233       else
234         llvm::errs() << "valid result\n";
235       if (HasResult && !CanonicalErrorMessage) {
236         llvm::errs() << "  Expected to Produce:\n";
237         dumpChanges(*CanonicalResult, llvm::errs());
238         llvm::errs() << "  Produced:\n";
239         dumpChanges(*Result, llvm::errs());
240       }
241     }
242 
243     // Dump the results:
244     const auto &TestGroup = TestRanges.GroupedRanges[Group.index()];
245     if (!CanonicalResult) {
246       llvm::outs() << TestGroup.Ranges.size() << " '" << TestGroup.Name
247                    << "' results:\n";
248       llvm::outs() << *CanonicalErrorMessage << "\n";
249     } else {
250       llvm::outs() << TestGroup.Ranges.size() << " '" << TestGroup.Name
251                    << "' results:\n";
252       if (printRewrittenSources(*CanonicalResult, llvm::outs()))
253         return true;
254     }
255   }
256   return Failed;
257 }
258 
259 std::unique_ptr<ClangRefactorToolConsumerInterface>
260 TestSelectionRangesInFile::createConsumer() const {
261   return std::make_unique<TestRefactoringResultConsumer>(*this);
262 }
263 
264 /// Adds the \p ColumnOffset to file offset \p Offset, without going past a
265 /// newline.
266 static unsigned addColumnOffset(StringRef Source, unsigned Offset,
267                                 unsigned ColumnOffset) {
268   if (!ColumnOffset)
269     return Offset;
270   StringRef Substr = Source.drop_front(Offset).take_front(ColumnOffset);
271   size_t NewlinePos = Substr.find_first_of("\r\n");
272   return Offset +
273          (NewlinePos == StringRef::npos ? ColumnOffset : (unsigned)NewlinePos);
274 }
275 
276 static unsigned addEndLineOffsetAndEndColumn(StringRef Source, unsigned Offset,
277                                              unsigned LineNumberOffset,
278                                              unsigned Column) {
279   StringRef Line = Source.drop_front(Offset);
280   unsigned LineOffset = 0;
281   for (; LineNumberOffset != 0; --LineNumberOffset) {
282     size_t NewlinePos = Line.find_first_of("\r\n");
283     // Line offset goes out of bounds.
284     if (NewlinePos == StringRef::npos)
285       break;
286     LineOffset += NewlinePos + 1;
287     Line = Line.drop_front(NewlinePos + 1);
288   }
289   // Source now points to the line at +lineOffset;
290   size_t LineStart = Source.find_last_of("\r\n", /*From=*/Offset + LineOffset);
291   return addColumnOffset(
292       Source, LineStart == StringRef::npos ? 0 : LineStart + 1, Column - 1);
293 }
294 
295 std::optional<TestSelectionRangesInFile>
296 findTestSelectionRanges(StringRef Filename) {
297   ErrorOr<std::unique_ptr<MemoryBuffer>> ErrOrFile =
298       MemoryBuffer::getFile(Filename);
299   if (!ErrOrFile) {
300     llvm::errs() << "error: -selection=test:" << Filename
301                  << " : could not open the given file";
302     return std::nullopt;
303   }
304   StringRef Source = ErrOrFile.get()->getBuffer();
305 
306   // See the doc comment for this function for the explanation of this
307   // syntax.
308   static const Regex RangeRegex(
309       "range[[:blank:]]*([[:alpha:]_]*)?[[:blank:]]*=[[:"
310       "blank:]]*(\\+[[:digit:]]+)?[[:blank:]]*(->[[:blank:]"
311       "]*[\\+\\:[:digit:]]+)?");
312 
313   std::map<std::string, SmallVector<TestSelectionRange, 8>> GroupedRanges;
314 
315   LangOptions LangOpts;
316   LangOpts.CPlusPlus = 1;
317   LangOpts.CPlusPlus11 = 1;
318   Lexer Lex(SourceLocation::getFromRawEncoding(0), LangOpts, Source.begin(),
319             Source.begin(), Source.end());
320   Lex.SetCommentRetentionState(true);
321   Token Tok;
322   for (Lex.LexFromRawLexer(Tok); Tok.isNot(tok::eof);
323        Lex.LexFromRawLexer(Tok)) {
324     if (Tok.isNot(tok::comment))
325       continue;
326     StringRef Comment =
327         Source.substr(Tok.getLocation().getRawEncoding(), Tok.getLength());
328     SmallVector<StringRef, 4> Matches;
329     // Try to detect mistyped 'range:' comments to ensure tests don't miss
330     // anything.
331     auto DetectMistypedCommand = [&]() -> bool {
332       if (Comment.contains_insensitive("range") && Comment.contains("=") &&
333           !Comment.contains_insensitive("run") && !Comment.contains("CHECK")) {
334         llvm::errs() << "error: suspicious comment '" << Comment
335                      << "' that "
336                         "resembles the range command found\n";
337         llvm::errs() << "note: please reword if this isn't a range command\n";
338       }
339       return false;
340     };
341     // Allow CHECK: comments to contain range= commands.
342     if (!RangeRegex.match(Comment, &Matches) || Comment.contains("CHECK")) {
343       if (DetectMistypedCommand())
344         return std::nullopt;
345       continue;
346     }
347     unsigned Offset = Tok.getEndLoc().getRawEncoding();
348     unsigned ColumnOffset = 0;
349     if (!Matches[2].empty()) {
350       // Don't forget to drop the '+'!
351       if (Matches[2].drop_front().getAsInteger(10, ColumnOffset))
352         assert(false && "regex should have produced a number");
353     }
354     Offset = addColumnOffset(Source, Offset, ColumnOffset);
355     unsigned EndOffset;
356 
357     if (!Matches[3].empty()) {
358       static const Regex EndLocRegex(
359           "->[[:blank:]]*(\\+[[:digit:]]+):([[:digit:]]+)");
360       SmallVector<StringRef, 4> EndLocMatches;
361       if (!EndLocRegex.match(Matches[3], &EndLocMatches)) {
362         if (DetectMistypedCommand())
363           return std::nullopt;
364         continue;
365       }
366       unsigned EndLineOffset = 0, EndColumn = 0;
367       if (EndLocMatches[1].drop_front().getAsInteger(10, EndLineOffset) ||
368           EndLocMatches[2].getAsInteger(10, EndColumn))
369         assert(false && "regex should have produced a number");
370       EndOffset = addEndLineOffsetAndEndColumn(Source, Offset, EndLineOffset,
371                                                EndColumn);
372     } else {
373       EndOffset = Offset;
374     }
375     TestSelectionRange Range = {Offset, EndOffset};
376     auto It = GroupedRanges.insert(std::make_pair(
377         Matches[1].str(), SmallVector<TestSelectionRange, 8>{Range}));
378     if (!It.second)
379       It.first->second.push_back(Range);
380   }
381   if (GroupedRanges.empty()) {
382     llvm::errs() << "error: -selection=test:" << Filename
383                  << ": no 'range' commands";
384     return std::nullopt;
385   }
386 
387   TestSelectionRangesInFile TestRanges = {Filename.str(), {}};
388   for (auto &Group : GroupedRanges)
389     TestRanges.GroupedRanges.push_back({Group.first, std::move(Group.second)});
390   return std::move(TestRanges);
391 }
392 
393 } // end namespace refactor
394 } // end namespace clang
395