1 //===- ClangDiff.cpp - compare source files by AST nodes ------*- C++ -*- -===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a tool for syntax tree based comparison using
10 // Tooling/ASTDiff.
11 //
12 //===----------------------------------------------------------------------===//
14 #include "clang/Tooling/ASTDiff/ASTDiff.h"
15 #include "clang/Tooling/CommonOptionsParser.h"
16 #include "clang/Tooling/Tooling.h"
17 #include "llvm/Support/CommandLine.h"
19 using namespace llvm;
20 using namespace clang;
21 using namespace clang::tooling;
23 static cl::OptionCategory ClangDiffCategory("clang-diff options");
25 static cl::opt<bool>
26     ASTDump("ast-dump",
27             cl::desc("Print the internal representation of the AST."),
28             cl::init(false), cl::cat(ClangDiffCategory));
30 static cl::opt<bool> ASTDumpJson(
31     "ast-dump-json",
32     cl::desc("Print the internal representation of the AST as JSON."),
33     cl::init(false), cl::cat(ClangDiffCategory));
35 static cl::opt<bool> PrintMatches("dump-matches",
36                                   cl::desc("Print the matched nodes."),
37                                   cl::init(false), cl::cat(ClangDiffCategory));
39 static cl::opt<bool> HtmlDiff("html",
40                               cl::desc("Output a side-by-side diff in HTML."),
41                               cl::init(false), cl::cat(ClangDiffCategory));
43 static cl::opt<std::string> SourcePath(cl::Positional, cl::desc("<source>"),
44                                        cl::Required,
45                                        cl::cat(ClangDiffCategory));
47 static cl::opt<std::string> DestinationPath(cl::Positional,
48                                             cl::desc("<destination>"),
49                                             cl::Optional,
50                                             cl::cat(ClangDiffCategory));
52 static cl::opt<std::string> StopAfter("stop-diff-after",
53                                       cl::desc("<topdown|bottomup>"),
54                                       cl::Optional, cl::init(""),
55                                       cl::cat(ClangDiffCategory));
57 static cl::opt<int> MaxSize("s", cl::desc("<maxsize>"), cl::Optional,
58                             cl::init(-1), cl::cat(ClangDiffCategory));
60 static cl::opt<std::string> BuildPath("p", cl::desc("Build path"), cl::init(""),
61                                       cl::Optional, cl::cat(ClangDiffCategory));
63 static cl::list<std::string> ArgsAfter(
64     "extra-arg",
65     cl::desc("Additional argument to append to the compiler command line"),
66     cl::cat(ClangDiffCategory));
68 static cl::list<std::string> ArgsBefore(
69     "extra-arg-before",
70     cl::desc("Additional argument to prepend to the compiler command line"),
71     cl::cat(ClangDiffCategory));
73 static void addExtraArgs(std::unique_ptr<CompilationDatabase> &Compilations) {
74   if (!Compilations)
75     return;
76   auto AdjustingCompilations =
77       std::make_unique<ArgumentsAdjustingCompilations>(
78           std::move(Compilations));
79   AdjustingCompilations->appendArgumentsAdjuster(
80       getInsertArgumentAdjuster(ArgsBefore, ArgumentInsertPosition::BEGIN));
81   AdjustingCompilations->appendArgumentsAdjuster(
82       getInsertArgumentAdjuster(ArgsAfter, ArgumentInsertPosition::END));
83   Compilations = std::move(AdjustingCompilations);
84 }
86 static std::unique_ptr<ASTUnit>
87 getAST(const std::unique_ptr<CompilationDatabase> &CommonCompilations,
88        const StringRef Filename) {
89   std::string ErrorMessage;
90   std::unique_ptr<CompilationDatabase> Compilations;
91   if (!CommonCompilations) {
92     Compilations = CompilationDatabase::autoDetectFromSource(
93         BuildPath.empty() ? Filename : BuildPath, ErrorMessage);
94     if (!Compilations) {
95       llvm::errs()
96           << "Error while trying to load a compilation database, running "
97              "without flags.\n"
98           << ErrorMessage;
99       Compilations =
100           std::make_unique<clang::tooling::FixedCompilationDatabase>(
101               ".", std::vector<std::string>());
102     }
103   }
104   addExtraArgs(Compilations);
105   std::array<std::string, 1> Files = {{std::string(Filename)}};
106   ClangTool Tool(Compilations ? *Compilations : *CommonCompilations, Files);
107   std::vector<std::unique_ptr<ASTUnit>> ASTs;
108   Tool.buildASTs(ASTs);
109   if (ASTs.size() != Files.size())
110     return nullptr;
111   return std::move(ASTs[0]);
112 }
/// Returns the lowercase hexadecimal digit ('0'-'9', 'a'-'f') that represents
/// the low four bits of \p N; all higher bits are ignored.
static char hexdigit(int N) {
  const int Nibble = N & 0xf;
  if (Nibble < 10)
    return static_cast<char>('0' + Nibble);
  return static_cast<char>('a' + (Nibble - 10));
}
/// Preamble of the page produced in -html mode; main() writes it before the
/// two `div.code` panes and closes the elements it leaves open.
///
/// The style sheet colors spans by class name: `d`, `u`, `i` and `m` are the
/// abbreviations returned by getChangeKindAbbr(). The script highlights the
/// clicked span together with the span named by its `tid` attribute, and binds
/// the J / K keys to step to the next / previous span that has no class.
static const char HtmlDiffHeader[] = R"(
<html>
<head>
<meta charset='utf-8'/>
<style>
span.d { color: red; }
span.u { color: #cc00cc; }
span.i { color: green; }
span.m { font-weight: bold; }
span   { font-weight: normal; color: black; }
div.code {
  width: 48%;
  height: 98%;
  overflow: scroll;
  float: left;
  padding: 0 0 0.5% 0.5%;
  border: solid 2px LightGrey;
  border-radius: 5px;
}
</style>
</head>
<script type='text/javascript'>
highlightStack = []
function clearHighlight() {
  while (highlightStack.length) {
    var [l, r] = highlightStack.pop()
    document.getElementById(l).style.backgroundColor = 'inherit'
    if (r[1] != '-')
      document.getElementById(r).style.backgroundColor = 'inherit'
  }
}
function highlight(event) {
  var id = event.target['id']
  doHighlight(id)
}
function doHighlight(id) {
  clearHighlight()
  source = document.getElementById(id)
  if (!source.attributes['tid'])
    return
  var mapped = source
  while (mapped && mapped.parentElement && mapped.attributes['tid'].value.substr(1) === '-1')
    mapped = mapped.parentElement
  var tid = null, target = null
  if (mapped) {
    tid = mapped.attributes['tid'].value
    target = document.getElementById(tid)
  }
  if (source.parentElement && source.parentElement.classList.contains('code'))
    return
  source.style.backgroundColor = 'lightgrey'
  source.scrollIntoView()
  if (target) {
    if (mapped === source)
      target.style.backgroundColor = 'lightgrey'
    target.scrollIntoView()
  }
  highlightStack.push([id, tid])
  location.hash = '#' + id
}
function scrollToBoth() {
  doHighlight(location.hash.substr(1))
}
function changed(elem) {
  return elem.classList.length == 0
}
function nextChangedNode(prefix, increment, number) {
  do {
    number += increment
    var elem = document.getElementById(prefix + number)
  } while(elem && !changed(elem))
  return elem ? number : null
}
function handleKey(e) {
  var down = e.code === "KeyJ"
  var up = e.code === "KeyK"
  if (!down && !up)
    return
  var id = highlightStack[0] ? highlightStack[0][0] : 'R0'
  var oldelem = document.getElementById(id)
  var number = parseInt(id.substr(1))
  var increment = down ? 1 : -1
  var lastnumber = number
  var prefix = id[0]
  do {
    number = nextChangedNode(prefix, increment, number)
    var elem = document.getElementById(prefix + number)
    if (up && elem) {
      while (elem.parentElement && changed(elem.parentElement))
        elem = elem.parentElement
      number = elem.id.substr(1)
    }
  } while ((down && id !== 'R0' && oldelem.contains(elem)))
  if (!number)
    number = lastnumber
  elem = document.getElementById(prefix + number)
  doHighlight(prefix + number)
}
window.onload = scrollToBoth
window.onkeydown = handleKey
</script>
<body>
<div onclick='highlight(event)'>
)";
221 static void printHtml(raw_ostream &OS, char C) {
222   switch (C) {
223   case '&':
224     OS << "&amp;";
225     break;
226   case '<':
227     OS << "&lt;";
228     break;
229   case '>':
230     OS << "&gt;";
231     break;
232   case '\'':
233     OS << "&#x27;";
234     break;
235   case '"':
236     OS << "&quot;";
237     break;
238   default:
239     OS << C;
240   }
241 }
243 static void printHtml(raw_ostream &OS, const StringRef Str) {
244   for (char C : Str)
245     printHtml(OS, C);
246 }
248 static std::string getChangeKindAbbr(diff::ChangeKind Kind) {
249   switch (Kind) {
250   case diff::None:
251     return "";
252   case diff::Delete:
253     return "d";
254   case diff::Update:
255     return "u";
256   case diff::Insert:
257     return "i";
258   case diff::Move:
259     return "m";
260   case diff::UpdateMove:
261     return "u m";
262   }
263   llvm_unreachable("Invalid enumeration value.");
264 }
266 static unsigned printHtmlForNode(raw_ostream &OS, const diff::ASTDiff &Diff,
267                                  diff::SyntaxTree &Tree, bool IsLeft,
268                                  diff::NodeId Id, unsigned Offset) {
269   const diff::Node &Node = Tree.getNode(Id);
270   char MyTag, OtherTag;
271   diff::NodeId LeftId, RightId;
272   diff::NodeId TargetId = Diff.getMapped(Tree, Id);
273   if (IsLeft) {
274     MyTag = 'L';
275     OtherTag = 'R';
276     LeftId = Id;
277     RightId = TargetId;
278   } else {
279     MyTag = 'R';
280     OtherTag = 'L';
281     LeftId = TargetId;
282     RightId = Id;
283   }
284   unsigned Begin, End;
285   std::tie(Begin, End) = Tree.getSourceRangeOffsets(Node);
286   const SourceManager &SrcMgr = Tree.getASTContext().getSourceManager();
287   auto Code = SrcMgr.getBufferOrFake(SrcMgr.getMainFileID()).getBuffer();
288   for (; Offset < Begin; ++Offset)
289     printHtml(OS, Code[Offset]);
290   OS << "<span id='" << MyTag << Id << "' "
291      << "tid='" << OtherTag << TargetId << "' ";
292   OS << "title='";
293   printHtml(OS, Node.getTypeLabel());
294   OS << "\n" << LeftId << " -> " << RightId;
295   std::string Value = Tree.getNodeValue(Node);
296   if (!Value.empty()) {
297     OS << "\n";
298     printHtml(OS, Value);
299   }
300   OS << "'";
301   if (Node.Change != diff::None)
302     OS << " class='" << getChangeKindAbbr(Node.Change) << "'";
303   OS << ">";
305   for (diff::NodeId Child : Node.Children)
306     Offset = printHtmlForNode(OS, Diff, Tree, IsLeft, Child, Offset);
308   for (; Offset < End; ++Offset)
309     printHtml(OS, Code[Offset]);
310   if (Id == Tree.getRootId()) {
311     End = Code.size();
312     for (; Offset < End; ++Offset)
313       printHtml(OS, Code[Offset]);
314   }
315   OS << "</span>";
316   return Offset;
317 }
319 static void printJsonString(raw_ostream &OS, const StringRef Str) {
320   for (signed char C : Str) {
321     switch (C) {
322     case '"':
323       OS << R"(\")";
324       break;
325     case '\\':
326       OS << R"(\\)";
327       break;
328     case '\n':
329       OS << R"(\n)";
330       break;
331     case '\t':
332       OS << R"(\t)";
333       break;
334     default:
335       if ('\x00' <= C && C <= '\x1f') {
336         OS << R"(\u00)" << hexdigit(C >> 4) << hexdigit(C);
337       } else {
338         OS << C;
339       }
340     }
341   }
342 }
344 static void printNodeAttributes(raw_ostream &OS, diff::SyntaxTree &Tree,
345                                 diff::NodeId Id) {
346   const diff::Node &N = Tree.getNode(Id);
347   OS << R"("id":)" << int(Id);
348   OS << R"(,"type":")" << N.getTypeLabel() << '"';
349   auto Offsets = Tree.getSourceRangeOffsets(N);
350   OS << R"(,"begin":)" << Offsets.first;
351   OS << R"(,"end":)" << Offsets.second;
352   std::string Value = Tree.getNodeValue(N);
353   if (!Value.empty()) {
354     OS << R"(,"value":")";
355     printJsonString(OS, Value);
356     OS << '"';
357   }
358 }
360 static void printNodeAsJson(raw_ostream &OS, diff::SyntaxTree &Tree,
361                             diff::NodeId Id) {
362   const diff::Node &N = Tree.getNode(Id);
363   OS << "{";
364   printNodeAttributes(OS, Tree, Id);
365   auto Identifier = N.getIdentifier();
366   auto QualifiedIdentifier = N.getQualifiedIdentifier();
367   if (Identifier) {
368     OS << R"(,"identifier":")";
369     printJsonString(OS, *Identifier);
370     OS << R"(")";
371     if (QualifiedIdentifier && *Identifier != *QualifiedIdentifier) {
372       OS << R"(,"qualified_identifier":")";
373       printJsonString(OS, *QualifiedIdentifier);
374       OS << R"(")";
375     }
376   }
377   OS << R"(,"children":[)";
378   if (N.Children.size() > 0) {
379     printNodeAsJson(OS, Tree, N.Children[0]);
380     for (size_t I = 1, E = N.Children.size(); I < E; ++I) {
381       OS << ",";
382       printNodeAsJson(OS, Tree, N.Children[I]);
383     }
384   }
385   OS << "]}";
386 }
388 static void printNode(raw_ostream &OS, diff::SyntaxTree &Tree,
389                       diff::NodeId Id) {
390   if (Id.isInvalid()) {
391     OS << "None";
392     return;
393   }
394   OS << Tree.getNode(Id).getTypeLabel();
395   std::string Value = Tree.getNodeValue(Id);
396   if (!Value.empty())
397     OS << ": " << Value;
398   OS << "(" << Id << ")";
399 }
401 static void printTree(raw_ostream &OS, diff::SyntaxTree &Tree) {
402   for (diff::NodeId Id : Tree) {
403     for (int I = 0; I < Tree.getNode(Id).Depth; ++I)
404       OS << " ";
405     printNode(OS, Tree, Id);
406     OS << "\n";
407   }
408 }
410 static void printDstChange(raw_ostream &OS, diff::ASTDiff &Diff,
411                            diff::SyntaxTree &SrcTree, diff::SyntaxTree &DstTree,
412                            diff::NodeId Dst) {
413   const diff::Node &DstNode = DstTree.getNode(Dst);
414   diff::NodeId Src = Diff.getMapped(DstTree, Dst);
415   switch (DstNode.Change) {
416   case diff::None:
417     break;
418   case diff::Delete:
419     llvm_unreachable("The destination tree can't have deletions.");
420   case diff::Update:
421     OS << "Update ";
422     printNode(OS, SrcTree, Src);
423     OS << " to " << DstTree.getNodeValue(Dst) << "\n";
424     break;
425   case diff::Insert:
426   case diff::Move:
427   case diff::UpdateMove:
428     if (DstNode.Change == diff::Insert)
429       OS << "Insert";
430     else if (DstNode.Change == diff::Move)
431       OS << "Move";
432     else if (DstNode.Change == diff::UpdateMove)
433       OS << "Update and Move";
434     OS << " ";
435     printNode(OS, DstTree, Dst);
436     OS << " into ";
437     printNode(OS, DstTree, DstNode.Parent);
438     OS << " at " << DstTree.findPositionInParent(Dst) << "\n";
439     break;
440   }
441 }
443 int main(int argc, const char **argv) {
444   std::string ErrorMessage;
445   std::unique_ptr<CompilationDatabase> CommonCompilations =
446       FixedCompilationDatabase::loadFromCommandLine(argc, argv, ErrorMessage);
447   if (!CommonCompilations && !ErrorMessage.empty())
448     llvm::errs() << ErrorMessage;
449   cl::HideUnrelatedOptions(ClangDiffCategory);
450   if (!cl::ParseCommandLineOptions(argc, argv)) {
451     cl::PrintOptionValues();
452     return 1;
453   }
455   addExtraArgs(CommonCompilations);
457   if (ASTDump || ASTDumpJson) {
458     if (!DestinationPath.empty()) {
459       llvm::errs() << "Error: Please specify exactly one filename.\n";
460       return 1;
461     }
462     std::unique_ptr<ASTUnit> AST = getAST(CommonCompilations, SourcePath);
463     if (!AST)
464       return 1;
465     diff::SyntaxTree Tree(AST->getASTContext());
466     if (ASTDump) {
467       printTree(llvm::outs(), Tree);
468       return 0;
469     }
470     llvm::outs() << R"({"filename":")";
471     printJsonString(llvm::outs(), SourcePath);
472     llvm::outs() << R"(","root":)";
473     printNodeAsJson(llvm::outs(), Tree, Tree.getRootId());
474     llvm::outs() << "}\n";
475     return 0;
476   }
478   if (DestinationPath.empty()) {
479     llvm::errs() << "Error: Exactly two paths are required.\n";
480     return 1;
481   }
483   std::unique_ptr<ASTUnit> Src = getAST(CommonCompilations, SourcePath);
484   std::unique_ptr<ASTUnit> Dst = getAST(CommonCompilations, DestinationPath);
485   if (!Src || !Dst)
486     return 1;
488   diff::ComparisonOptions Options;
489   if (MaxSize != -1)
490     Options.MaxSize = MaxSize;
491   if (!StopAfter.empty()) {
492     if (StopAfter == "topdown")
493       Options.StopAfterTopDown = true;
494     else if (StopAfter != "bottomup") {
495       llvm::errs() << "Error: Invalid argument for -stop-after\n";
496       return 1;
497     }
498   }
499   diff::SyntaxTree SrcTree(Src->getASTContext());
500   diff::SyntaxTree DstTree(Dst->getASTContext());
501   diff::ASTDiff Diff(SrcTree, DstTree, Options);
503   if (HtmlDiff) {
504     llvm::outs() << HtmlDiffHeader << "<pre>";
505     llvm::outs() << "<div id='L' class='code'>";
506     printHtmlForNode(llvm::outs(), Diff, SrcTree, true, SrcTree.getRootId(), 0);
507     llvm::outs() << "</div>";
508     llvm::outs() << "<div id='R' class='code'>";
509     printHtmlForNode(llvm::outs(), Diff, DstTree, false, DstTree.getRootId(),
510                      0);
511     llvm::outs() << "</div>";
512     llvm::outs() << "</pre></div></body></html>\n";
513     return 0;
514   }
516   for (diff::NodeId Dst : DstTree) {
517     diff::NodeId Src = Diff.getMapped(DstTree, Dst);
518     if (PrintMatches && Src.isValid()) {
519       llvm::outs() << "Match ";
520       printNode(llvm::outs(), SrcTree, Src);
521       llvm::outs() << " to ";
522       printNode(llvm::outs(), DstTree, Dst);
523       llvm::outs() << "\n";
524     }
525     printDstChange(llvm::outs(), Diff, SrcTree, DstTree, Dst);
526   }
527   for (diff::NodeId Src : SrcTree) {
528     if (Diff.getMapped(SrcTree, Src).isInvalid()) {
529       llvm::outs() << "Delete ";
530       printNode(llvm::outs(), SrcTree, Src);
531       llvm::outs() << "\n";
532     }
533   }
535   return 0;
536 }