1 //===--- TestSupport.cpp - Clang-based refactoring tool -------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 ///
9 /// \file
10 /// This file implements routines that provide refactoring testing
11 /// utilities.
12 ///
13 //===----------------------------------------------------------------------===//
14 
15 #include "TestSupport.h"
16 #include "clang/Basic/DiagnosticError.h"
17 #include "clang/Basic/FileManager.h"
18 #include "clang/Basic/SourceManager.h"
19 #include "clang/Lex/Lexer.h"
20 #include "llvm/ADT/STLExtras.h"
21 #include "llvm/Support/Error.h"
22 #include "llvm/Support/ErrorOr.h"
23 #include "llvm/Support/LineIterator.h"
24 #include "llvm/Support/MemoryBuffer.h"
25 #include "llvm/Support/Regex.h"
26 #include "llvm/Support/raw_ostream.h"
27 
28 using namespace llvm;
29 
30 namespace clang {
31 namespace refactor {
32 
dump(raw_ostream & OS) const33 void TestSelectionRangesInFile::dump(raw_ostream &OS) const {
34   for (const auto &Group : GroupedRanges) {
35     OS << "Test selection group '" << Group.Name << "':\n";
36     for (const auto &Range : Group.Ranges) {
37       OS << "  " << Range.Begin << "-" << Range.End << "\n";
38     }
39   }
40 }
41 
foreachRange(const SourceManager & SM,llvm::function_ref<void (SourceRange)> Callback) const42 bool TestSelectionRangesInFile::foreachRange(
43     const SourceManager &SM,
44     llvm::function_ref<void(SourceRange)> Callback) const {
45   auto FE = SM.getFileManager().getFile(Filename);
46   FileID FID = FE ? SM.translateFile(*FE) : FileID();
47   if (!FE || FID.isInvalid()) {
48     llvm::errs() << "error: -selection=test:" << Filename
49                  << " : given file is not in the target TU";
50     return true;
51   }
52   SourceLocation FileLoc = SM.getLocForStartOfFile(FID);
53   for (const auto &Group : GroupedRanges) {
54     for (const TestSelectionRange &Range : Group.Ranges) {
55       // Translate the offset pair to a true source range.
56       SourceLocation Start =
57           SM.getMacroArgExpandedLocation(FileLoc.getLocWithOffset(Range.Begin));
58       SourceLocation End =
59           SM.getMacroArgExpandedLocation(FileLoc.getLocWithOffset(Range.End));
60       assert(Start.isValid() && End.isValid() && "unexpected invalid range");
61       Callback(SourceRange(Start, End));
62     }
63   }
64   return false;
65 }
66 
67 namespace {
68 
dumpChanges(const tooling::AtomicChanges & Changes,raw_ostream & OS)69 void dumpChanges(const tooling::AtomicChanges &Changes, raw_ostream &OS) {
70   for (const auto &Change : Changes)
71     OS << const_cast<tooling::AtomicChange &>(Change).toYAMLString() << "\n";
72 }
73 
areChangesSame(const tooling::AtomicChanges & LHS,const tooling::AtomicChanges & RHS)74 bool areChangesSame(const tooling::AtomicChanges &LHS,
75                     const tooling::AtomicChanges &RHS) {
76   if (LHS.size() != RHS.size())
77     return false;
78   for (auto I : llvm::zip(LHS, RHS)) {
79     if (!(std::get<0>(I) == std::get<1>(I)))
80       return false;
81   }
82   return true;
83 }
84 
printRewrittenSources(const tooling::AtomicChanges & Changes,raw_ostream & OS)85 bool printRewrittenSources(const tooling::AtomicChanges &Changes,
86                            raw_ostream &OS) {
87   std::set<std::string> Files;
88   for (const auto &Change : Changes)
89     Files.insert(Change.getFilePath());
90   tooling::ApplyChangesSpec Spec;
91   Spec.Cleanup = false;
92   for (const auto &File : Files) {
93     llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> BufferErr =
94         llvm::MemoryBuffer::getFile(File);
95     if (!BufferErr) {
96       llvm::errs() << "failed to open" << File << "\n";
97       return true;
98     }
99     auto Result = tooling::applyAtomicChanges(File, (*BufferErr)->getBuffer(),
100                                               Changes, Spec);
101     if (!Result) {
102       llvm::errs() << toString(Result.takeError());
103       return true;
104     }
105     OS << *Result;
106   }
107   return false;
108 }
109 
110 class TestRefactoringResultConsumer final
111     : public ClangRefactorToolConsumerInterface {
112 public:
TestRefactoringResultConsumer(const TestSelectionRangesInFile & TestRanges)113   TestRefactoringResultConsumer(const TestSelectionRangesInFile &TestRanges)
114       : TestRanges(TestRanges) {
115     Results.push_back({});
116   }
117 
~TestRefactoringResultConsumer()118   ~TestRefactoringResultConsumer() {
119     // Ensure all results are checked.
120     for (auto &Group : Results) {
121       for (auto &Result : Group) {
122         if (!Result) {
123           (void)llvm::toString(Result.takeError());
124         }
125       }
126     }
127   }
128 
handleError(llvm::Error Err)129   void handleError(llvm::Error Err) override { handleResult(std::move(Err)); }
130 
handle(tooling::AtomicChanges Changes)131   void handle(tooling::AtomicChanges Changes) override {
132     handleResult(std::move(Changes));
133   }
134 
handle(tooling::SymbolOccurrences Occurrences)135   void handle(tooling::SymbolOccurrences Occurrences) override {
136     tooling::RefactoringResultConsumer::handle(std::move(Occurrences));
137   }
138 
139 private:
140   bool handleAllResults();
141 
handleResult(Expected<tooling::AtomicChanges> Result)142   void handleResult(Expected<tooling::AtomicChanges> Result) {
143     Results.back().push_back(std::move(Result));
144     size_t GroupIndex = Results.size() - 1;
145     if (Results.back().size() >=
146         TestRanges.GroupedRanges[GroupIndex].Ranges.size()) {
147       ++GroupIndex;
148       if (GroupIndex >= TestRanges.GroupedRanges.size()) {
149         if (handleAllResults())
150           exit(1); // error has occurred.
151         return;
152       }
153       Results.push_back({});
154     }
155   }
156 
157   const TestSelectionRangesInFile &TestRanges;
158   std::vector<std::vector<Expected<tooling::AtomicChanges>>> Results;
159 };
160 
getLineColumn(StringRef Filename,unsigned Offset)161 std::pair<unsigned, unsigned> getLineColumn(StringRef Filename,
162                                             unsigned Offset) {
163   ErrorOr<std::unique_ptr<MemoryBuffer>> ErrOrFile =
164       MemoryBuffer::getFile(Filename);
165   if (!ErrOrFile)
166     return {0, 0};
167   StringRef Source = ErrOrFile.get()->getBuffer();
168   Source = Source.take_front(Offset);
169   size_t LastLine = Source.find_last_of("\r\n");
170   return {Source.count('\n') + 1,
171           (LastLine == StringRef::npos ? Offset : Offset - LastLine) + 1};
172 }
173 
174 } // end anonymous namespace
175 
handleAllResults()176 bool TestRefactoringResultConsumer::handleAllResults() {
177   bool Failed = false;
178   for (auto &Group : llvm::enumerate(Results)) {
179     // All ranges in the group must produce the same result.
180     Optional<tooling::AtomicChanges> CanonicalResult;
181     Optional<std::string> CanonicalErrorMessage;
182     for (auto &I : llvm::enumerate(Group.value())) {
183       Expected<tooling::AtomicChanges> &Result = I.value();
184       std::string ErrorMessage;
185       bool HasResult = !!Result;
186       if (!HasResult) {
187         handleAllErrors(
188             Result.takeError(),
189             [&](StringError &Err) { ErrorMessage = Err.getMessage(); },
190             [&](DiagnosticError &Err) {
191               const PartialDiagnosticAt &Diag = Err.getDiagnostic();
192               llvm::SmallString<100> DiagText;
193               Diag.second.EmitToString(getDiags(), DiagText);
194               ErrorMessage = std::string(DiagText);
195             });
196       }
197       if (!CanonicalResult && !CanonicalErrorMessage) {
198         if (HasResult)
199           CanonicalResult = std::move(*Result);
200         else
201           CanonicalErrorMessage = std::move(ErrorMessage);
202         continue;
203       }
204 
205       // Verify that this result corresponds to the canonical result.
206       if (CanonicalErrorMessage) {
207         // The error messages must match.
208         if (!HasResult && ErrorMessage == *CanonicalErrorMessage)
209           continue;
210       } else {
211         assert(CanonicalResult && "missing canonical result");
212         // The results must match.
213         if (HasResult && areChangesSame(*Result, *CanonicalResult))
214           continue;
215       }
216       Failed = true;
217       // Report the mismatch.
218       std::pair<unsigned, unsigned> LineColumn = getLineColumn(
219           TestRanges.Filename,
220           TestRanges.GroupedRanges[Group.index()].Ranges[I.index()].Begin);
221       llvm::errs()
222           << "error: unexpected refactoring result for range starting at "
223           << LineColumn.first << ':' << LineColumn.second << " in group '"
224           << TestRanges.GroupedRanges[Group.index()].Name << "':\n  ";
225       if (HasResult)
226         llvm::errs() << "valid result";
227       else
228         llvm::errs() << "error '" << ErrorMessage << "'";
229       llvm::errs() << " does not match initial ";
230       if (CanonicalErrorMessage)
231         llvm::errs() << "error '" << *CanonicalErrorMessage << "'\n";
232       else
233         llvm::errs() << "valid result\n";
234       if (HasResult && !CanonicalErrorMessage) {
235         llvm::errs() << "  Expected to Produce:\n";
236         dumpChanges(*CanonicalResult, llvm::errs());
237         llvm::errs() << "  Produced:\n";
238         dumpChanges(*Result, llvm::errs());
239       }
240     }
241 
242     // Dump the results:
243     const auto &TestGroup = TestRanges.GroupedRanges[Group.index()];
244     if (!CanonicalResult) {
245       llvm::outs() << TestGroup.Ranges.size() << " '" << TestGroup.Name
246                    << "' results:\n";
247       llvm::outs() << *CanonicalErrorMessage << "\n";
248     } else {
249       llvm::outs() << TestGroup.Ranges.size() << " '" << TestGroup.Name
250                    << "' results:\n";
251       if (printRewrittenSources(*CanonicalResult, llvm::outs()))
252         return true;
253     }
254   }
255   return Failed;
256 }
257 
258 std::unique_ptr<ClangRefactorToolConsumerInterface>
createConsumer() const259 TestSelectionRangesInFile::createConsumer() const {
260   return std::make_unique<TestRefactoringResultConsumer>(*this);
261 }
262 
263 /// Adds the \p ColumnOffset to file offset \p Offset, without going past a
264 /// newline.
addColumnOffset(StringRef Source,unsigned Offset,unsigned ColumnOffset)265 static unsigned addColumnOffset(StringRef Source, unsigned Offset,
266                                 unsigned ColumnOffset) {
267   if (!ColumnOffset)
268     return Offset;
269   StringRef Substr = Source.drop_front(Offset).take_front(ColumnOffset);
270   size_t NewlinePos = Substr.find_first_of("\r\n");
271   return Offset +
272          (NewlinePos == StringRef::npos ? ColumnOffset : (unsigned)NewlinePos);
273 }
274 
addEndLineOffsetAndEndColumn(StringRef Source,unsigned Offset,unsigned LineNumberOffset,unsigned Column)275 static unsigned addEndLineOffsetAndEndColumn(StringRef Source, unsigned Offset,
276                                              unsigned LineNumberOffset,
277                                              unsigned Column) {
278   StringRef Line = Source.drop_front(Offset);
279   unsigned LineOffset = 0;
280   for (; LineNumberOffset != 0; --LineNumberOffset) {
281     size_t NewlinePos = Line.find_first_of("\r\n");
282     // Line offset goes out of bounds.
283     if (NewlinePos == StringRef::npos)
284       break;
285     LineOffset += NewlinePos + 1;
286     Line = Line.drop_front(NewlinePos + 1);
287   }
288   // Source now points to the line at +lineOffset;
289   size_t LineStart = Source.find_last_of("\r\n", /*From=*/Offset + LineOffset);
290   return addColumnOffset(
291       Source, LineStart == StringRef::npos ? 0 : LineStart + 1, Column - 1);
292 }
293 
294 Optional<TestSelectionRangesInFile>
findTestSelectionRanges(StringRef Filename)295 findTestSelectionRanges(StringRef Filename) {
296   ErrorOr<std::unique_ptr<MemoryBuffer>> ErrOrFile =
297       MemoryBuffer::getFile(Filename);
298   if (!ErrOrFile) {
299     llvm::errs() << "error: -selection=test:" << Filename
300                  << " : could not open the given file";
301     return None;
302   }
303   StringRef Source = ErrOrFile.get()->getBuffer();
304 
305   // See the doc comment for this function for the explanation of this
306   // syntax.
307   static const Regex RangeRegex(
308       "range[[:blank:]]*([[:alpha:]_]*)?[[:blank:]]*=[[:"
309       "blank:]]*(\\+[[:digit:]]+)?[[:blank:]]*(->[[:blank:]"
310       "]*[\\+\\:[:digit:]]+)?");
311 
312   std::map<std::string, SmallVector<TestSelectionRange, 8>> GroupedRanges;
313 
314   LangOptions LangOpts;
315   LangOpts.CPlusPlus = 1;
316   LangOpts.CPlusPlus11 = 1;
317   Lexer Lex(SourceLocation::getFromRawEncoding(0), LangOpts, Source.begin(),
318             Source.begin(), Source.end());
319   Lex.SetCommentRetentionState(true);
320   Token Tok;
321   for (Lex.LexFromRawLexer(Tok); Tok.isNot(tok::eof);
322        Lex.LexFromRawLexer(Tok)) {
323     if (Tok.isNot(tok::comment))
324       continue;
325     StringRef Comment =
326         Source.substr(Tok.getLocation().getRawEncoding(), Tok.getLength());
327     SmallVector<StringRef, 4> Matches;
328     // Try to detect mistyped 'range:' comments to ensure tests don't miss
329     // anything.
330     auto DetectMistypedCommand = [&]() -> bool {
331       if (Comment.contains_insensitive("range") && Comment.contains("=") &&
332           !Comment.contains_insensitive("run") && !Comment.contains("CHECK")) {
333         llvm::errs() << "error: suspicious comment '" << Comment
334                      << "' that "
335                         "resembles the range command found\n";
336         llvm::errs() << "note: please reword if this isn't a range command\n";
337       }
338       return false;
339     };
340     // Allow CHECK: comments to contain range= commands.
341     if (!RangeRegex.match(Comment, &Matches) || Comment.contains("CHECK")) {
342       if (DetectMistypedCommand())
343         return None;
344       continue;
345     }
346     unsigned Offset = Tok.getEndLoc().getRawEncoding();
347     unsigned ColumnOffset = 0;
348     if (!Matches[2].empty()) {
349       // Don't forget to drop the '+'!
350       if (Matches[2].drop_front().getAsInteger(10, ColumnOffset))
351         assert(false && "regex should have produced a number");
352     }
353     Offset = addColumnOffset(Source, Offset, ColumnOffset);
354     unsigned EndOffset;
355 
356     if (!Matches[3].empty()) {
357       static const Regex EndLocRegex(
358           "->[[:blank:]]*(\\+[[:digit:]]+):([[:digit:]]+)");
359       SmallVector<StringRef, 4> EndLocMatches;
360       if (!EndLocRegex.match(Matches[3], &EndLocMatches)) {
361         if (DetectMistypedCommand())
362           return None;
363         continue;
364       }
365       unsigned EndLineOffset = 0, EndColumn = 0;
366       if (EndLocMatches[1].drop_front().getAsInteger(10, EndLineOffset) ||
367           EndLocMatches[2].getAsInteger(10, EndColumn))
368         assert(false && "regex should have produced a number");
369       EndOffset = addEndLineOffsetAndEndColumn(Source, Offset, EndLineOffset,
370                                                EndColumn);
371     } else {
372       EndOffset = Offset;
373     }
374     TestSelectionRange Range = {Offset, EndOffset};
375     auto It = GroupedRanges.insert(std::make_pair(
376         Matches[1].str(), SmallVector<TestSelectionRange, 8>{Range}));
377     if (!It.second)
378       It.first->second.push_back(Range);
379   }
380   if (GroupedRanges.empty()) {
381     llvm::errs() << "error: -selection=test:" << Filename
382                  << ": no 'range' commands";
383     return None;
384   }
385 
386   TestSelectionRangesInFile TestRanges = {Filename.str(), {}};
387   for (auto &Group : GroupedRanges)
388     TestRanges.GroupedRanges.push_back({Group.first, std::move(Group.second)});
389   return std::move(TestRanges);
390 }
391 
392 } // end namespace refactor
393 } // end namespace clang
394