1 //===- ClangDiff.cpp - compare source files by AST nodes ------*- C++ -*- -===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 //
10 // This file implements a tool for syntax tree based comparison using
11 // Tooling/ASTDiff.
12 //
13 //===----------------------------------------------------------------------===//
14 
15 #include "clang/Tooling/ASTDiff/ASTDiff.h"
16 #include "clang/Tooling/CommonOptionsParser.h"
17 #include "clang/Tooling/Tooling.h"
18 #include "llvm/Support/CommandLine.h"
19 
20 using namespace llvm;
21 using namespace clang;
22 using namespace clang::tooling;
23 
// All clang-diff specific options live in this category; main() calls
// cl::HideUnrelatedOptions() with it so that -help lists only these.
static cl::OptionCategory ClangDiffCategory("clang-diff options");

// -ast-dump: parse only <source> and print its syntax tree as indented text
// (see printTree). Incompatible with a <destination> argument.
static cl::opt<bool>
    ASTDump("ast-dump",
            cl::desc("Print the internal representation of the AST."),
            cl::init(false), cl::cat(ClangDiffCategory));

// -ast-dump-json: like -ast-dump, but emits the tree as a JSON document
// (see printNodeAsJson). -ast-dump wins if both are given.
static cl::opt<bool> ASTDumpJson(
    "ast-dump-json",
    cl::desc("Print the internal representation of the AST as JSON."),
    cl::init(false), cl::cat(ClangDiffCategory));

// -dump-matches: in the textual diff, also print a "Match" line for every
// destination node that is mapped to a source node.
static cl::opt<bool> PrintMatches("dump-matches",
                                  cl::desc("Print the matched nodes."),
                                  cl::init(false), cl::cat(ClangDiffCategory));

// -html: print the diff as a side-by-side HTML page instead of an edit script.
static cl::opt<bool> HtmlDiff("html",
                              cl::desc("Output a side-by-side diff in HTML."),
                              cl::init(false), cl::cat(ClangDiffCategory));

// The two positional arguments. cl matches positional arguments in the order
// the options are constructed, so SourcePath must stay declared before
// DestinationPath. <destination> is optional only for the -ast-dump modes;
// main() rejects a missing one when diffing.
static cl::opt<std::string> SourcePath(cl::Positional, cl::desc("<source>"),
                                       cl::Required,
                                       cl::cat(ClangDiffCategory));

static cl::opt<std::string> DestinationPath(cl::Positional,
                                            cl::desc("<destination>"),
                                            cl::Optional,
                                            cl::cat(ClangDiffCategory));

// -stop-diff-after: "topdown" sets ComparisonOptions::StopAfterTopDown;
// "bottomup" is accepted and leaves the default; anything else is an error.
static cl::opt<std::string> StopAfter("stop-diff-after",
                                      cl::desc("<topdown|bottomup>"),
                                      cl::Optional, cl::init(""),
                                      cl::cat(ClangDiffCategory));

// -s: forwarded to ComparisonOptions::MaxSize; -1 (the default) means "not
// given" and keeps the library default.
static cl::opt<int> MaxSize("s", cl::desc("<maxsize>"), cl::Optional,
                            cl::init(-1), cl::cat(ClangDiffCategory));

// -p: directory used by getAST() to look for a compilation database when no
// fixed compile command was supplied on the command line.
static cl::opt<std::string> BuildPath("p", cl::desc("Build path"), cl::init(""),
                                      cl::Optional, cl::cat(ClangDiffCategory));

// -extra-arg / -extra-arg-before: inserted into every compile command by
// addExtraArgs(), at the end and at the beginning respectively.
static cl::list<std::string> ArgsAfter(
    "extra-arg",
    cl::desc("Additional argument to append to the compiler command line"),
    cl::cat(ClangDiffCategory));

static cl::list<std::string> ArgsBefore(
    "extra-arg-before",
    cl::desc("Additional argument to prepend to the compiler command line"),
    cl::cat(ClangDiffCategory));
73 
addExtraArgs(std::unique_ptr<CompilationDatabase> & Compilations)74 static void addExtraArgs(std::unique_ptr<CompilationDatabase> &Compilations) {
75   if (!Compilations)
76     return;
77   auto AdjustingCompilations =
78       llvm::make_unique<ArgumentsAdjustingCompilations>(
79           std::move(Compilations));
80   AdjustingCompilations->appendArgumentsAdjuster(
81       getInsertArgumentAdjuster(ArgsBefore, ArgumentInsertPosition::BEGIN));
82   AdjustingCompilations->appendArgumentsAdjuster(
83       getInsertArgumentAdjuster(ArgsAfter, ArgumentInsertPosition::END));
84   Compilations = std::move(AdjustingCompilations);
85 }
86 
87 static std::unique_ptr<ASTUnit>
getAST(const std::unique_ptr<CompilationDatabase> & CommonCompilations,const StringRef Filename)88 getAST(const std::unique_ptr<CompilationDatabase> &CommonCompilations,
89        const StringRef Filename) {
90   std::string ErrorMessage;
91   std::unique_ptr<CompilationDatabase> Compilations;
92   if (!CommonCompilations) {
93     Compilations = CompilationDatabase::autoDetectFromSource(
94         BuildPath.empty() ? Filename : BuildPath, ErrorMessage);
95     if (!Compilations) {
96       llvm::errs()
97           << "Error while trying to load a compilation database, running "
98              "without flags.\n"
99           << ErrorMessage;
100       Compilations =
101           llvm::make_unique<clang::tooling::FixedCompilationDatabase>(
102               ".", std::vector<std::string>());
103     }
104   }
105   addExtraArgs(Compilations);
106   std::array<std::string, 1> Files = {{Filename}};
107   ClangTool Tool(Compilations ? *Compilations : *CommonCompilations, Files);
108   std::vector<std::unique_ptr<ASTUnit>> ASTs;
109   Tool.buildASTs(ASTs);
110   if (ASTs.size() != Files.size())
111     return nullptr;
112   return std::move(ASTs[0]);
113 }
114 
/// Returns the lowercase hexadecimal digit for the low four bits of \p N.
static char hexdigit(int N) {
  static const char Digits[] = "0123456789abcdef";
  return Digits[N & 0xf];
}
116 
// Static prologue of the page produced by -html. main() appends "<pre>", the
// two side-by-side "code" divs holding the annotated source and the closing
// tags. The embedded script does two things:
//   - It highlights a clicked node together with the node it is mapped to,
//     which is linked through the 'tid' attribute.
//   - It lets 'j'/'k' jump to the next/previous changed node. A changed node
//     is a node span that carries a change-kind class, because
//     printHtmlForNode only emits a class attribute for changed nodes.
static const char HtmlDiffHeader[] = R"(
<html>
<head>
<meta charset='utf-8'/>
<style>
span.d { color: red; }
span.u { color: #cc00cc; }
span.i { color: green; }
span.m { font-weight: bold; }
span   { font-weight: normal; color: black; }
div.code {
  width: 48%;
  height: 98%;
  overflow: scroll;
  float: left;
  padding: 0 0 0.5% 0.5%;
  border: solid 2px LightGrey;
  border-radius: 5px;
}
</style>
</head>
<script type='text/javascript'>
highlightStack = []
function clearHighlight() {
  while (highlightStack.length) {
    var [l, r] = highlightStack.pop()
    document.getElementById(l).style.backgroundColor = 'inherit'
    if (r[1] != '-')
      document.getElementById(r).style.backgroundColor = 'inherit'
  }
}
function highlight(event) {
  var id = event.target['id']
  doHighlight(id)
}
function doHighlight(id) {
  clearHighlight()
  var source = document.getElementById(id)
  if (!source || !source.attributes['tid'])
    return
  var mapped = source
  while (mapped && mapped.parentElement && mapped.attributes['tid'].value.substr(1) === '-1')
    mapped = mapped.parentElement
  var tid = null, target = null
  if (mapped) {
    tid = mapped.attributes['tid'].value
    target = document.getElementById(tid)
  }
  if (source.parentElement && source.parentElement.classList.contains('code'))
    return
  source.style.backgroundColor = 'lightgrey'
  source.scrollIntoView()
  if (target) {
    if (mapped === source)
      target.style.backgroundColor = 'lightgrey'
    target.scrollIntoView()
  }
  highlightStack.push([id, tid])
  location.hash = '#' + id
}
function scrollToBoth() {
  doHighlight(location.hash.substr(1))
}
function changed(elem) {
  return elem.classList.length != 0 && elem.hasAttribute('tid')
}
function nextChangedNode(prefix, increment, number) {
  do {
    number += increment
    var elem = document.getElementById(prefix + number)
  } while(elem && !changed(elem))
  return elem ? number : null
}
function handleKey(e) {
  var down = e.code === "KeyJ"
  var up = e.code === "KeyK"
  if (!down && !up)
    return
  var id = highlightStack[0] ? highlightStack[0][0] : 'R0'
  var oldelem = document.getElementById(id)
  var number = parseInt(id.substr(1))
  var increment = down ? 1 : -1
  var lastnumber = number
  var prefix = id[0]
  do {
    number = nextChangedNode(prefix, increment, number)
    var elem = document.getElementById(prefix + number)
    if (up && elem) {
      while (elem.parentElement && changed(elem.parentElement))
        elem = elem.parentElement
      number = elem.id.substr(1)
    }
  } while ((down && id !== 'R0' && oldelem.contains(elem)))
  if (!number)
    number = lastnumber
  elem = document.getElementById(prefix + number)
  doHighlight(prefix + number)
}
window.onload = scrollToBoth
window.onkeydown = handleKey
</script>
<body>
<div onclick='highlight(event)'>
)";
221 
printHtml(raw_ostream & OS,char C)222 static void printHtml(raw_ostream &OS, char C) {
223   switch (C) {
224   case '&':
225     OS << "&amp;";
226     break;
227   case '<':
228     OS << "&lt;";
229     break;
230   case '>':
231     OS << "&gt;";
232     break;
233   case '\'':
234     OS << "&#x27;";
235     break;
236   case '"':
237     OS << "&quot;";
238     break;
239   default:
240     OS << C;
241   }
242 }
243 
printHtml(raw_ostream & OS,const StringRef Str)244 static void printHtml(raw_ostream &OS, const StringRef Str) {
245   for (char C : Str)
246     printHtml(OS, C);
247 }
248 
getChangeKindAbbr(diff::ChangeKind Kind)249 static std::string getChangeKindAbbr(diff::ChangeKind Kind) {
250   switch (Kind) {
251   case diff::None:
252     return "";
253   case diff::Delete:
254     return "d";
255   case diff::Update:
256     return "u";
257   case diff::Insert:
258     return "i";
259   case diff::Move:
260     return "m";
261   case diff::UpdateMove:
262     return "u m";
263   }
264   llvm_unreachable("Invalid enumeration value.");
265 }
266 
printHtmlForNode(raw_ostream & OS,const diff::ASTDiff & Diff,diff::SyntaxTree & Tree,bool IsLeft,diff::NodeId Id,unsigned Offset)267 static unsigned printHtmlForNode(raw_ostream &OS, const diff::ASTDiff &Diff,
268                                  diff::SyntaxTree &Tree, bool IsLeft,
269                                  diff::NodeId Id, unsigned Offset) {
270   const diff::Node &Node = Tree.getNode(Id);
271   char MyTag, OtherTag;
272   diff::NodeId LeftId, RightId;
273   diff::NodeId TargetId = Diff.getMapped(Tree, Id);
274   if (IsLeft) {
275     MyTag = 'L';
276     OtherTag = 'R';
277     LeftId = Id;
278     RightId = TargetId;
279   } else {
280     MyTag = 'R';
281     OtherTag = 'L';
282     LeftId = TargetId;
283     RightId = Id;
284   }
285   unsigned Begin, End;
286   std::tie(Begin, End) = Tree.getSourceRangeOffsets(Node);
287   const SourceManager &SrcMgr = Tree.getASTContext().getSourceManager();
288   auto Code = SrcMgr.getBuffer(SrcMgr.getMainFileID())->getBuffer();
289   for (; Offset < Begin; ++Offset)
290     printHtml(OS, Code[Offset]);
291   OS << "<span id='" << MyTag << Id << "' "
292      << "tid='" << OtherTag << TargetId << "' ";
293   OS << "title='";
294   printHtml(OS, Node.getTypeLabel());
295   OS << "\n" << LeftId << " -> " << RightId;
296   std::string Value = Tree.getNodeValue(Node);
297   if (!Value.empty()) {
298     OS << "\n";
299     printHtml(OS, Value);
300   }
301   OS << "'";
302   if (Node.Change != diff::None)
303     OS << " class='" << getChangeKindAbbr(Node.Change) << "'";
304   OS << ">";
305 
306   for (diff::NodeId Child : Node.Children)
307     Offset = printHtmlForNode(OS, Diff, Tree, IsLeft, Child, Offset);
308 
309   for (; Offset < End; ++Offset)
310     printHtml(OS, Code[Offset]);
311   if (Id == Tree.getRootId()) {
312     End = Code.size();
313     for (; Offset < End; ++Offset)
314       printHtml(OS, Code[Offset]);
315   }
316   OS << "</span>";
317   return Offset;
318 }
319 
printJsonString(raw_ostream & OS,const StringRef Str)320 static void printJsonString(raw_ostream &OS, const StringRef Str) {
321   for (signed char C : Str) {
322     switch (C) {
323     case '"':
324       OS << R"(\")";
325       break;
326     case '\\':
327       OS << R"(\\)";
328       break;
329     case '\n':
330       OS << R"(\n)";
331       break;
332     case '\t':
333       OS << R"(\t)";
334       break;
335     default:
336       if ('\x00' <= C && C <= '\x1f') {
337         OS << R"(\u00)" << hexdigit(C >> 4) << hexdigit(C);
338       } else {
339         OS << C;
340       }
341     }
342   }
343 }
344 
printNodeAttributes(raw_ostream & OS,diff::SyntaxTree & Tree,diff::NodeId Id)345 static void printNodeAttributes(raw_ostream &OS, diff::SyntaxTree &Tree,
346                                 diff::NodeId Id) {
347   const diff::Node &N = Tree.getNode(Id);
348   OS << R"("id":)" << int(Id);
349   OS << R"(,"type":")" << N.getTypeLabel() << '"';
350   auto Offsets = Tree.getSourceRangeOffsets(N);
351   OS << R"(,"begin":)" << Offsets.first;
352   OS << R"(,"end":)" << Offsets.second;
353   std::string Value = Tree.getNodeValue(N);
354   if (!Value.empty()) {
355     OS << R"(,"value":")";
356     printJsonString(OS, Value);
357     OS << '"';
358   }
359 }
360 
361 static void printNodeAsJson(raw_ostream &OS, diff::SyntaxTree &Tree,
362                             diff::NodeId Id) {
363   const diff::Node &N = Tree.getNode(Id);
364   OS << "{";
365   printNodeAttributes(OS, Tree, Id);
366   auto Identifier = N.getIdentifier();
367   auto QualifiedIdentifier = N.getQualifiedIdentifier();
368   if (Identifier) {
369     OS << R"(,"identifier":")";
370     printJsonString(OS, *Identifier);
371     OS << R"(")";
372     if (QualifiedIdentifier && *Identifier != *QualifiedIdentifier) {
373       OS << R"(,"qualified_identifier":")";
374       printJsonString(OS, *QualifiedIdentifier);
375       OS << R"(")";
376     }
377   }
378   OS << R"(,"children":[)";
379   if (N.Children.size() > 0) {
380     printNodeAsJson(OS, Tree, N.Children[0]);
381     for (size_t I = 1, E = N.Children.size(); I < E; ++I) {
382       OS << ",";
383       printNodeAsJson(OS, Tree, N.Children[I]);
384     }
385   }
386   OS << "]}";
387 }
388 
389 static void printNode(raw_ostream &OS, diff::SyntaxTree &Tree,
390                       diff::NodeId Id) {
391   if (Id.isInvalid()) {
392     OS << "None";
393     return;
394   }
395   OS << Tree.getNode(Id).getTypeLabel();
396   std::string Value = Tree.getNodeValue(Id);
397   if (!Value.empty())
398     OS << ": " << Value;
399   OS << "(" << Id << ")";
400 }
401 
402 static void printTree(raw_ostream &OS, diff::SyntaxTree &Tree) {
403   for (diff::NodeId Id : Tree) {
404     for (int I = 0; I < Tree.getNode(Id).Depth; ++I)
405       OS << " ";
406     printNode(OS, Tree, Id);
407     OS << "\n";
408   }
409 }
410 
411 static void printDstChange(raw_ostream &OS, diff::ASTDiff &Diff,
412                            diff::SyntaxTree &SrcTree, diff::SyntaxTree &DstTree,
413                            diff::NodeId Dst) {
414   const diff::Node &DstNode = DstTree.getNode(Dst);
415   diff::NodeId Src = Diff.getMapped(DstTree, Dst);
416   switch (DstNode.Change) {
417   case diff::None:
418     break;
419   case diff::Delete:
420     llvm_unreachable("The destination tree can't have deletions.");
421   case diff::Update:
422     OS << "Update ";
423     printNode(OS, SrcTree, Src);
424     OS << " to " << DstTree.getNodeValue(Dst) << "\n";
425     break;
426   case diff::Insert:
427   case diff::Move:
428   case diff::UpdateMove:
429     if (DstNode.Change == diff::Insert)
430       OS << "Insert";
431     else if (DstNode.Change == diff::Move)
432       OS << "Move";
433     else if (DstNode.Change == diff::UpdateMove)
434       OS << "Update and Move";
435     OS << " ";
436     printNode(OS, DstTree, Dst);
437     OS << " into ";
438     printNode(OS, DstTree, DstNode.Parent);
439     OS << " at " << DstTree.findPositionInParent(Dst) << "\n";
440     break;
441   }
442 }
443 
444 int main(int argc, const char **argv) {
445   std::string ErrorMessage;
446   std::unique_ptr<CompilationDatabase> CommonCompilations =
447       FixedCompilationDatabase::loadFromCommandLine(argc, argv, ErrorMessage);
448   if (!CommonCompilations && !ErrorMessage.empty())
449     llvm::errs() << ErrorMessage;
450   cl::HideUnrelatedOptions(ClangDiffCategory);
451   if (!cl::ParseCommandLineOptions(argc, argv)) {
452     cl::PrintOptionValues();
453     return 1;
454   }
455 
456   addExtraArgs(CommonCompilations);
457 
458   if (ASTDump || ASTDumpJson) {
459     if (!DestinationPath.empty()) {
460       llvm::errs() << "Error: Please specify exactly one filename.\n";
461       return 1;
462     }
463     std::unique_ptr<ASTUnit> AST = getAST(CommonCompilations, SourcePath);
464     if (!AST)
465       return 1;
466     diff::SyntaxTree Tree(AST->getASTContext());
467     if (ASTDump) {
468       printTree(llvm::outs(), Tree);
469       return 0;
470     }
471     llvm::outs() << R"({"filename":")";
472     printJsonString(llvm::outs(), SourcePath);
473     llvm::outs() << R"(","root":)";
474     printNodeAsJson(llvm::outs(), Tree, Tree.getRootId());
475     llvm::outs() << "}\n";
476     return 0;
477   }
478 
479   if (DestinationPath.empty()) {
480     llvm::errs() << "Error: Exactly two paths are required.\n";
481     return 1;
482   }
483 
484   std::unique_ptr<ASTUnit> Src = getAST(CommonCompilations, SourcePath);
485   std::unique_ptr<ASTUnit> Dst = getAST(CommonCompilations, DestinationPath);
486   if (!Src || !Dst)
487     return 1;
488 
489   diff::ComparisonOptions Options;
490   if (MaxSize != -1)
491     Options.MaxSize = MaxSize;
492   if (!StopAfter.empty()) {
493     if (StopAfter == "topdown")
494       Options.StopAfterTopDown = true;
495     else if (StopAfter != "bottomup") {
496       llvm::errs() << "Error: Invalid argument for -stop-after\n";
497       return 1;
498     }
499   }
500   diff::SyntaxTree SrcTree(Src->getASTContext());
501   diff::SyntaxTree DstTree(Dst->getASTContext());
502   diff::ASTDiff Diff(SrcTree, DstTree, Options);
503 
504   if (HtmlDiff) {
505     llvm::outs() << HtmlDiffHeader << "<pre>";
506     llvm::outs() << "<div id='L' class='code'>";
507     printHtmlForNode(llvm::outs(), Diff, SrcTree, true, SrcTree.getRootId(), 0);
508     llvm::outs() << "</div>";
509     llvm::outs() << "<div id='R' class='code'>";
510     printHtmlForNode(llvm::outs(), Diff, DstTree, false, DstTree.getRootId(),
511                      0);
512     llvm::outs() << "</div>";
513     llvm::outs() << "</pre></div></body></html>\n";
514     return 0;
515   }
516 
517   for (diff::NodeId Dst : DstTree) {
518     diff::NodeId Src = Diff.getMapped(DstTree, Dst);
519     if (PrintMatches && Src.isValid()) {
520       llvm::outs() << "Match ";
521       printNode(llvm::outs(), SrcTree, Src);
522       llvm::outs() << " to ";
523       printNode(llvm::outs(), DstTree, Dst);
524       llvm::outs() << "\n";
525     }
526     printDstChange(llvm::outs(), Diff, SrcTree, DstTree, Dst);
527   }
528   for (diff::NodeId Src : SrcTree) {
529     if (Diff.getMapped(SrcTree, Src).isInvalid()) {
530       llvm::outs() << "Delete ";
531       printNode(llvm::outs(), SrcTree, Src);
532       llvm::outs() << "\n";
533     }
534   }
535 
536   return 0;
537 }
538