1 //===--- TestTU.cpp - Scratch source files for testing --------------------===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===---------------------------------------------------------------------===//
9 #include "TestTU.h"
10 #include "TestFS.h"
11 #include "index/FileIndex.h"
12 #include "index/MemIndex.h"
13 #include "clang/AST/RecursiveASTVisitor.h"
14 #include "clang/Frontend/CompilerInvocation.h"
15 #include "clang/Frontend/PCHContainerOperations.h"
16 #include "clang/Frontend/Utils.h"
17 
18 namespace clang {
19 namespace clangd {
20 using namespace llvm;
21 
buildTestFS(llvm::StringMap<std::string> const & Files,llvm::StringMap<time_t> const & Timestamps)22 ParsedAST TestTU::build() const {
23   std::string FullFilename = testPath(Filename),
24               FullHeaderName = testPath(HeaderFilename);
25   std::vector<const char *> Cmd = {"clang", FullFilename.c_str()};
26   // FIXME: this shouldn't need to be conditional, but it breaks a
27   // GoToDefinition test for some reason (getMacroArgExpandedLocation fails).
28   if (!HeaderCode.empty()) {
29     Cmd.push_back("-include");
30     Cmd.push_back(FullHeaderName.c_str());
31   }
32   Cmd.insert(Cmd.end(), ExtraArgs.begin(), ExtraArgs.end());
33   auto AST = ParsedAST::build(
34       createInvocationFromCommandLine(Cmd), nullptr,
35       MemoryBuffer::getMemBufferCopy(Code),
36       std::make_shared<PCHContainerOperations>(),
37       buildTestFS({{FullFilename, Code}, {FullHeaderName, HeaderCode}}));
38   if (!AST.hasValue()) {
39     ADD_FAILURE() << "Failed to build code:\n" << Code;
40     llvm_unreachable("Failed to build TestTU!");
41   }
42   return std::move(*AST);
43 }
44 
45 SymbolSlab TestTU::headerSymbols() const {
46   auto AST = build();
47   return indexAST(AST.getASTContext(), AST.getPreprocessorPtr());
48 }
49 
50 std::unique_ptr<SymbolIndex> TestTU::index() const {
51   return MemIndex::build(headerSymbols());
52 }
testRoot()53 
54 const Symbol &findSymbol(const SymbolSlab &Slab, llvm::StringRef QName) {
55   const Symbol *Result = nullptr;
56   for (const Symbol &S : Slab) {
57     if (QName != (S.Scope + S.Name).str())
58       continue;
59     if (Result) {
60       ADD_FAILURE() << "Multiple symbols named " << QName << ":\n"
61                     << *Result << "\n---\n"
62                     << S;
63       assert(false && "QName is not unique");
64     }
65     Result = &S;
66   }
67   if (!Result) {
68     ADD_FAILURE() << "No symbol named " << QName << " in "
69                   << ::testing::PrintToString(Slab);
70     assert(false && "No symbol with QName");
71   }
72   return *Result;
73 }
74 
75 const NamedDecl &findDecl(ParsedAST &AST, llvm::StringRef QName) {
76   llvm::SmallVector<llvm::StringRef, 4> Components;
77   QName.split(Components, "::");
78 
getAbsolutePath(llvm::StringRef,llvm::StringRef Body,llvm::StringRef HintPath) const79   auto &Ctx = AST.getASTContext();
80   auto LookupDecl = [&Ctx](const DeclContext &Scope,
81                            llvm::StringRef Name) -> const NamedDecl & {
82     auto LookupRes = Scope.lookup(DeclarationName(&Ctx.Idents.get(Name)));
83     assert(!LookupRes.empty() && "Lookup failed");
84     assert(LookupRes.size() == 1 && "Lookup returned multiple results");
85     return *LookupRes.front();
86   };
87 
88   const DeclContext *Scope = Ctx.getTranslationUnitDecl();
89   for (auto NameIt = Components.begin(), End = Components.end() - 1;
90        NameIt != End; ++NameIt) {
91     Scope = &cast<DeclContext>(LookupDecl(*Scope, *NameIt));
uriFromAbsolutePath(llvm::StringRef AbsolutePath) const92   }
93   return LookupDecl(*Scope, Components.back());
94 }
95 
96 const NamedDecl &findAnyDecl(ParsedAST &AST,
97                              std::function<bool(const NamedDecl &)> Callback) {
98   struct Visitor : RecursiveASTVisitor<Visitor> {
99     decltype(Callback) CB;
100     llvm::SmallVector<const NamedDecl *, 1> Decls;
101     bool VisitNamedDecl(const NamedDecl *ND) {
102       if (CB(*ND))
103         Decls.push_back(ND);
104       return true;
105     }
106   } Visitor;
107   Visitor.CB = Callback;
108   for (Decl *D : AST.getLocalTopLevelDecls())
109     Visitor.TraverseDecl(D);
110   if (Visitor.Decls.size() != 1) {
111     ADD_FAILURE() << Visitor.Decls.size() << " symbols matched.";
112     assert(Visitor.Decls.size() == 1);
113   }
114   return *Visitor.Decls.front();
115 }
116 
117 const NamedDecl &findAnyDecl(ParsedAST &AST, llvm::StringRef Name) {
118   return findAnyDecl(AST, [Name](const NamedDecl &ND) {
119     if (auto *ID = ND.getIdentifier())
120       if (ID->getName() == Name)
121         return true;
122     return false;
123   });
124 }
125 
126 } // namespace clangd
127 } // namespace clang
128