1 //===- ASTSrcLocProcessor.cpp --------------------------------*- C++ -*----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "ASTSrcLocProcessor.h"
10 
11 #include "clang/Frontend/CompilerInstance.h"
12 #include "llvm/Support/JSON.h"
13 #include "llvm/Support/MemoryBuffer.h"
14 
15 using namespace clang::tooling;
16 using namespace llvm;
17 using namespace clang::ast_matchers;
18 
19 ASTSrcLocProcessor::ASTSrcLocProcessor(StringRef JsonPath)
20     : JsonPath(JsonPath) {
21 
22   MatchFinder::MatchFinderOptions FinderOptions;
23 
24   Finder = std::make_unique<MatchFinder>(std::move(FinderOptions));
25   Finder->addMatcher(
26       cxxRecordDecl(
27           isDefinition(),
28           isSameOrDerivedFrom(
29               namedDecl(
30                   hasAnyName(
31                       "clang::Stmt", "clang::Decl", "clang::CXXCtorInitializer",
32                       "clang::NestedNameSpecifierLoc",
33                       "clang::TemplateArgumentLoc", "clang::CXXBaseSpecifier",
34                       "clang::DeclarationNameInfo", "clang::TypeLoc"))
35                   .bind("nodeClade")),
36           optionally(isDerivedFrom(cxxRecordDecl().bind("derivedFrom"))))
37           .bind("className"),
38       this);
39   Finder->addMatcher(
40           cxxRecordDecl(isDefinition(), hasAnyName("clang::PointerLikeTypeLoc",
41                                                    "clang::TypeofLikeTypeLoc"))
42               .bind("templateName"),
43       this);
44 }
45 
46 std::unique_ptr<clang::ASTConsumer>
47 ASTSrcLocProcessor::createASTConsumer(clang::CompilerInstance &Compiler,
48                                       StringRef File) {
49   return Finder->newASTConsumer();
50 }
51 
52 llvm::json::Object toJSON(llvm::StringMap<std::vector<StringRef>> const &Obj) {
53   using llvm::json::toJSON;
54 
55   llvm::json::Object JsonObj;
56   for (const auto &Item : Obj) {
57     JsonObj[Item.first()] = Item.second;
58   }
59   return JsonObj;
60 }
61 
62 llvm::json::Object toJSON(llvm::StringMap<std::string> const &Obj) {
63   using llvm::json::toJSON;
64 
65   llvm::json::Object JsonObj;
66   for (const auto &Item : Obj) {
67     JsonObj[Item.first()] = Item.second;
68   }
69   return JsonObj;
70 }
71 
72 llvm::json::Object toJSON(ClassData const &Obj) {
73   llvm::json::Object JsonObj;
74 
75   if (!Obj.ASTClassLocations.empty())
76     JsonObj["sourceLocations"] = Obj.ASTClassLocations;
77   if (!Obj.ASTClassRanges.empty())
78     JsonObj["sourceRanges"] = Obj.ASTClassRanges;
79   if (!Obj.TemplateParms.empty())
80     JsonObj["templateParms"] = Obj.TemplateParms;
81   if (!Obj.TypeSourceInfos.empty())
82     JsonObj["typeSourceInfos"] = Obj.TypeSourceInfos;
83   if (!Obj.TypeLocs.empty())
84     JsonObj["typeLocs"] = Obj.TypeLocs;
85   if (!Obj.NestedNameLocs.empty())
86     JsonObj["nestedNameLocs"] = Obj.NestedNameLocs;
87   if (!Obj.DeclNameInfos.empty())
88     JsonObj["declNameInfos"] = Obj.DeclNameInfos;
89   return JsonObj;
90 }
91 
92 llvm::json::Object toJSON(llvm::StringMap<ClassData> const &Obj) {
93   using llvm::json::toJSON;
94 
95   llvm::json::Object JsonObj;
96   for (const auto &Item : Obj)
97     JsonObj[Item.first()] = ::toJSON(Item.second);
98   return JsonObj;
99 }
100 
101 void WriteJSON(StringRef JsonPath, llvm::json::Object &&ClassInheritance,
102                llvm::json::Object &&ClassesInClade,
103                llvm::json::Object &&ClassEntries) {
104   llvm::json::Object JsonObj;
105 
106   using llvm::json::toJSON;
107 
108   JsonObj["classInheritance"] = std::move(ClassInheritance);
109   JsonObj["classesInClade"] = std::move(ClassesInClade);
110   JsonObj["classEntries"] = std::move(ClassEntries);
111 
112   llvm::json::Value JsonVal(std::move(JsonObj));
113 
114   bool WriteChange = false;
115   std::string OutString;
116   if (auto ExistingOrErr = MemoryBuffer::getFile(JsonPath, /*IsText=*/true)) {
117     raw_string_ostream Out(OutString);
118     Out << formatv("{0:2}", JsonVal);
119     if (ExistingOrErr.get()->getBuffer() == Out.str())
120       return;
121     WriteChange = true;
122   }
123 
124   std::error_code EC;
125   llvm::raw_fd_ostream JsonOut(JsonPath, EC, llvm::sys::fs::OF_Text);
126   if (EC)
127     return;
128 
129   if (WriteChange)
130     JsonOut << OutString;
131   else
132     JsonOut << formatv("{0:2}", JsonVal);
133 }
134 
135 void ASTSrcLocProcessor::generate() {
136   WriteJSON(JsonPath, ::toJSON(ClassInheritance), ::toJSON(ClassesInClade),
137             ::toJSON(ClassEntries));
138 }
139 
140 void ASTSrcLocProcessor::generateEmpty() { WriteJSON(JsonPath, {}, {}, {}); }
141 
142 std::vector<std::string>
143 CaptureMethods(std::string TypeString, const clang::CXXRecordDecl *ASTClass,
144                const MatchFinder::MatchResult &Result) {
145 
146   auto publicAccessor = [](auto... InnerMatcher) {
147     return cxxMethodDecl(isPublic(), parameterCountIs(0), isConst(),
148                          InnerMatcher...);
149   };
150 
151   auto BoundNodesVec = match(
152       findAll(
153           publicAccessor(
154               ofClass(cxxRecordDecl(
155                   equalsNode(ASTClass),
156                   optionally(isDerivedFrom(
157                       cxxRecordDecl(hasAnyName("clang::Stmt", "clang::Decl"))
158                           .bind("stmtOrDeclBase"))),
159                   optionally(isDerivedFrom(
160                       cxxRecordDecl(hasName("clang::Expr")).bind("exprBase"))),
161                   optionally(
162                       isDerivedFrom(cxxRecordDecl(hasName("clang::TypeLoc"))
163                                         .bind("typeLocBase"))))),
164               returns(hasCanonicalType(asString(TypeString))))
165               .bind("classMethod")),
166       *ASTClass, *Result.Context);
167 
168   std::vector<std::string> Methods;
169   for (const auto &BN : BoundNodesVec) {
170     if (const auto *Node = BN.getNodeAs<clang::NamedDecl>("classMethod")) {
171       const auto *StmtOrDeclBase =
172           BN.getNodeAs<clang::CXXRecordDecl>("stmtOrDeclBase");
173       const auto *TypeLocBase =
174           BN.getNodeAs<clang::CXXRecordDecl>("typeLocBase");
175       const auto *ExprBase = BN.getNodeAs<clang::CXXRecordDecl>("exprBase");
176       // The clang AST has several methods on base classes which are overriden
177       // pseudo-virtually by derived classes.
178       // We record only the pseudo-virtual methods on the base classes to
179       // avoid duplication.
180       if (StmtOrDeclBase &&
181           (Node->getName() == "getBeginLoc" || Node->getName() == "getEndLoc" ||
182            Node->getName() == "getSourceRange"))
183         continue;
184       if (ExprBase && Node->getName() == "getExprLoc")
185         continue;
186       if (TypeLocBase && Node->getName() == "getLocalSourceRange")
187         continue;
188       if ((ASTClass->getName() == "PointerLikeTypeLoc" ||
189            ASTClass->getName() == "TypeofLikeTypeLoc") &&
190           Node->getName() == "getLocalSourceRange")
191         continue;
192       Methods.push_back(Node->getName().str());
193     }
194   }
195   return Methods;
196 }
197 
198 void ASTSrcLocProcessor::run(const MatchFinder::MatchResult &Result) {
199 
200   const auto *ASTClass =
201       Result.Nodes.getNodeAs<clang::CXXRecordDecl>("className");
202 
203   StringRef CladeName;
204   if (ASTClass) {
205     if (const auto *NodeClade =
206             Result.Nodes.getNodeAs<clang::CXXRecordDecl>("nodeClade"))
207       CladeName = NodeClade->getName();
208   } else {
209     ASTClass = Result.Nodes.getNodeAs<clang::CXXRecordDecl>("templateName");
210     CladeName = "TypeLoc";
211   }
212 
213   StringRef ClassName = ASTClass->getName();
214 
215   ClassData CD;
216 
217   CD.ASTClassLocations =
218       CaptureMethods("class clang::SourceLocation", ASTClass, Result);
219   CD.ASTClassRanges =
220       CaptureMethods("class clang::SourceRange", ASTClass, Result);
221   CD.TypeSourceInfos =
222       CaptureMethods("class clang::TypeSourceInfo *", ASTClass, Result);
223   CD.TypeLocs = CaptureMethods("class clang::TypeLoc", ASTClass, Result);
224   CD.NestedNameLocs =
225       CaptureMethods("class clang::NestedNameSpecifierLoc", ASTClass, Result);
226   CD.DeclNameInfos =
227       CaptureMethods("struct clang::DeclarationNameInfo", ASTClass, Result);
228   auto DI = CaptureMethods("const struct clang::DeclarationNameInfo &",
229                            ASTClass, Result);
230   CD.DeclNameInfos.insert(CD.DeclNameInfos.end(), DI.begin(), DI.end());
231 
232   if (const auto *DerivedFrom =
233           Result.Nodes.getNodeAs<clang::CXXRecordDecl>("derivedFrom")) {
234 
235     if (const auto *Templ =
236             llvm::dyn_cast<clang::ClassTemplateSpecializationDecl>(
237                 DerivedFrom)) {
238 
239       const auto &TArgs = Templ->getTemplateArgs();
240 
241       SmallString<256> TArgsString;
242       llvm::raw_svector_ostream OS(TArgsString);
243       OS << DerivedFrom->getName() << '<';
244 
245       clang::PrintingPolicy PPol(Result.Context->getLangOpts());
246       PPol.TerseOutput = true;
247 
248       for (unsigned I = 0; I < TArgs.size(); ++I) {
249         if (I > 0)
250           OS << ", ";
251         TArgs.get(I).getAsType().print(OS, PPol);
252       }
253       OS << '>';
254 
255       ClassInheritance[ClassName] = TArgsString.str().str();
256     } else {
257       ClassInheritance[ClassName] = DerivedFrom->getName().str();
258     }
259   }
260 
261   if (const auto *Templ = ASTClass->getDescribedClassTemplate()) {
262     if (auto *TParams = Templ->getTemplateParameters()) {
263       for (const auto &TParam : *TParams) {
264         CD.TemplateParms.push_back(TParam->getName().str());
265       }
266     }
267   }
268 
269   ClassEntries[ClassName] = CD;
270   ClassesInClade[CladeName].push_back(ClassName);
271 }
272