1 //===- unittest/Tooling/CrossTranslationUnitTest.cpp - Tooling unit tests -===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
#include "clang/CrossTU/CrossTranslationUnit.h"
#include "clang/AST/ASTConsumer.h"
#include "clang/AST/ParentMapContext.h"
#include "clang/Frontend/CompilerInstance.h"
#include "clang/Frontend/FrontendAction.h"
#include "clang/Tooling/Tooling.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/Optional.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Error.h"
#include "llvm/Support/FileSystem.h"
#include "llvm/Support/Path.h"
#include "llvm/Support/ToolOutputFile.h"
#include "gtest/gtest.h"
#include <cassert>
21 
22 namespace clang {
23 namespace cross_tu {
24 
25 namespace {
26 
27 class CTUASTConsumer : public clang::ASTConsumer {
28 public:
CTUASTConsumer(clang::CompilerInstance & CI,bool * Success)29   explicit CTUASTConsumer(clang::CompilerInstance &CI, bool *Success)
30       : CTU(CI), Success(Success) {}
31 
HandleTranslationUnit(ASTContext & Ctx)32   void HandleTranslationUnit(ASTContext &Ctx) {
33     auto FindFInTU = [](const TranslationUnitDecl *TU) {
34       const FunctionDecl *FD = nullptr;
35       for (const Decl *D : TU->decls()) {
36         FD = dyn_cast<FunctionDecl>(D);
37         if (FD && FD->getName() == "f")
38           break;
39       }
40       return FD;
41     };
42 
43     const TranslationUnitDecl *TU = Ctx.getTranslationUnitDecl();
44     const FunctionDecl *FD = FindFInTU(TU);
45     assert(FD && FD->getName() == "f");
46     bool OrigFDHasBody = FD->hasBody();
47 
48     const DynTypedNodeList ParentsBeforeImport =
49         Ctx.getParentMapContext().getParents<Decl>(*FD);
50     ASSERT_FALSE(ParentsBeforeImport.empty());
51 
52     // Prepare the index file and the AST file.
53     int ASTFD;
54     llvm::SmallString<256> ASTFileName;
55     ASSERT_FALSE(
56         llvm::sys::fs::createTemporaryFile("f_ast", "ast", ASTFD, ASTFileName));
57     llvm::ToolOutputFile ASTFile(ASTFileName, ASTFD);
58 
59     int IndexFD;
60     llvm::SmallString<256> IndexFileName;
61     ASSERT_FALSE(llvm::sys::fs::createTemporaryFile("index", "txt", IndexFD,
62                                                     IndexFileName));
63     llvm::ToolOutputFile IndexFile(IndexFileName, IndexFD);
64     IndexFile.os() << "c:@F@f#I# " << ASTFileName << "\n";
65     IndexFile.os().flush();
66     EXPECT_TRUE(llvm::sys::fs::exists(IndexFileName));
67 
68     StringRef SourceText = "int f(int) { return 0; }\n";
69     // This file must exist since the saved ASTFile will reference it.
70     int SourceFD;
71     llvm::SmallString<256> SourceFileName;
72     ASSERT_FALSE(llvm::sys::fs::createTemporaryFile("input", "cpp", SourceFD,
73                                                     SourceFileName));
74     llvm::ToolOutputFile SourceFile(SourceFileName, SourceFD);
75     SourceFile.os() << SourceText;
76     SourceFile.os().flush();
77     EXPECT_TRUE(llvm::sys::fs::exists(SourceFileName));
78 
79     std::unique_ptr<ASTUnit> ASTWithDefinition =
80         tooling::buildASTFromCode(SourceText, SourceFileName);
81     ASTWithDefinition->Save(ASTFileName.str());
82     EXPECT_TRUE(llvm::sys::fs::exists(ASTFileName));
83 
84     // Load the definition from the AST file.
85     llvm::Expected<const FunctionDecl *> NewFDorError = handleExpected(
86         CTU.getCrossTUDefinition(FD, "", IndexFileName, false),
87         []() { return nullptr; }, [](IndexError &) {});
88 
89     if (NewFDorError) {
90       const FunctionDecl *NewFD = *NewFDorError;
91       *Success = NewFD && NewFD->hasBody() && !OrigFDHasBody;
92 
93       if (NewFD) {
94         // Check GetImportedFromSourceLocation.
95         llvm::Optional<std::pair<SourceLocation, ASTUnit *>> SLocResult =
96             CTU.getImportedFromSourceLocation(NewFD->getLocation());
97         EXPECT_TRUE(SLocResult);
98         if (SLocResult) {
99           SourceLocation OrigSLoc = (*SLocResult).first;
100           ASTUnit *OrigUnit = (*SLocResult).second;
101           // OrigUnit is created internally by CTU (is not the
102           // ASTWithDefinition).
103           TranslationUnitDecl *OrigTU =
104               OrigUnit->getASTContext().getTranslationUnitDecl();
105           const FunctionDecl *FDWithDefinition = FindFInTU(OrigTU);
106           EXPECT_TRUE(FDWithDefinition);
107           if (FDWithDefinition) {
108             EXPECT_EQ(FDWithDefinition->getName(), "f");
109             EXPECT_TRUE(FDWithDefinition->isThisDeclarationADefinition());
110             EXPECT_EQ(OrigSLoc, FDWithDefinition->getLocation());
111           }
112         }
113 
114         // Check parent map.
115         const DynTypedNodeList ParentsAfterImport =
116             Ctx.getParentMapContext().getParents<Decl>(*FD);
117         const DynTypedNodeList ParentsOfImported =
118             Ctx.getParentMapContext().getParents<Decl>(*NewFD);
119         EXPECT_TRUE(
120             checkParentListsEq(ParentsBeforeImport, ParentsAfterImport));
121         EXPECT_FALSE(ParentsOfImported.empty());
122       }
123     }
124   }
125 
checkParentListsEq(const DynTypedNodeList & L1,const DynTypedNodeList & L2)126   static bool checkParentListsEq(const DynTypedNodeList &L1,
127                                  const DynTypedNodeList &L2) {
128     if (L1.size() != L2.size())
129       return false;
130     for (unsigned int I = 0; I < L1.size(); ++I)
131       if (L1[I] != L2[I])
132         return false;
133     return true;
134   }
135 
136 private:
137   CrossTranslationUnitContext CTU;
138   bool *Success;
139 };
140 
141 class CTUAction : public clang::ASTFrontendAction {
142 public:
CTUAction(bool * Success,unsigned OverrideLimit)143   CTUAction(bool *Success, unsigned OverrideLimit)
144       : Success(Success), OverrideLimit(OverrideLimit) {}
145 
146 protected:
147   std::unique_ptr<clang::ASTConsumer>
CreateASTConsumer(clang::CompilerInstance & CI,StringRef)148   CreateASTConsumer(clang::CompilerInstance &CI, StringRef) override {
149     CI.getAnalyzerOpts()->CTUImportThreshold = OverrideLimit;
150     CI.getAnalyzerOpts()->CTUImportCppThreshold = OverrideLimit;
151     return std::make_unique<CTUASTConsumer>(CI, Success);
152   }
153 
154 private:
155   bool *Success;
156   const unsigned OverrideLimit;
157 };
158 
159 } // end namespace
160 
TEST(CrossTranslationUnit, CanLoadFunctionDefinition) {
  // With an import threshold of 1 the consumer is expected to obtain the body
  // of f(), which this translation unit only declares.
  bool Success = false;
  const bool ToolRan = tooling::runToolOnCode(
      std::make_unique<CTUAction>(&Success, /*OverrideLimit=*/1u),
      "int f(int);");
  EXPECT_TRUE(ToolRan);
  EXPECT_TRUE(Success);
}
167 
TEST(CrossTranslationUnit, RespectsLoadThreshold) {
  // With an import threshold of 0 the definition must not be loaded, so the
  // consumer is expected to leave Success false.
  bool Success = false;
  const bool ToolRan = tooling::runToolOnCode(
      std::make_unique<CTUAction>(&Success, /*OverrideLimit=*/0u),
      "int f(int);");
  EXPECT_TRUE(ToolRan);
  EXPECT_FALSE(Success);
}
174 
/// Round-trips an index through createCrossTUIndexString / parseCrossTUIndex
/// and checks that exactly the same key/value pairs come back.
TEST(CrossTranslationUnit, IndexFormatCanBeParsed) {
  llvm::StringMap<std::string> Index;
  Index["a"] = "/b/f1";
  Index["c"] = "/d/f2";
  Index["e"] = "/f/f3";
  std::string IndexText = createCrossTUIndexString(Index);

  int IndexFD;
  llvm::SmallString<256> IndexFileName;
  ASSERT_FALSE(llvm::sys::fs::createTemporaryFile("index", "txt", IndexFD,
                                                  IndexFileName));
  llvm::ToolOutputFile IndexFile(IndexFileName, IndexFD);
  IndexFile.os() << IndexText;
  IndexFile.os().flush();
  EXPECT_TRUE(llvm::sys::fs::exists(IndexFileName));

  llvm::Expected<llvm::StringMap<std::string>> IndexOrErr =
      parseCrossTUIndex(IndexFileName);
  // Stop on failure instead of dereferencing an errored Expected. Taking the
  // error also keeps llvm::Expected from aborting on an unchecked error.
  ASSERT_TRUE(static_cast<bool>(IndexOrErr))
      << llvm::toString(IndexOrErr.takeError());
  const llvm::StringMap<std::string> &ParsedIndex = *IndexOrErr;

  // Every original entry must be present with the same value. Use find()
  // rather than operator[] so a missing key is not default-inserted.
  for (const auto &E : Index) {
    auto It = ParsedIndex.find(E.getKey());
    EXPECT_TRUE(It != ParsedIndex.end())
        << "missing key: " << E.getKey().str();
    if (It != ParsedIndex.end())
      EXPECT_EQ(It->getValue(), E.getValue());
  }
  // ...and no extra entries may appear.
  for (const auto &E : ParsedIndex)
    EXPECT_TRUE(Index.count(E.getKey()));
}
201 
/// An empty invocation list must be rejected with the
/// invocation_list_wrong_format error code.
TEST(CrossTranslationUnit, EmptyInvocationListIsNotValid) {
  auto Input = "";

  llvm::Expected<InvocationListTy> Result = parseInvocationList(Input);
  EXPECT_FALSE(static_cast<bool>(Result));
  bool IsWrongFormatError = false;
  llvm::handleAllErrors(Result.takeError(), [&](IndexError &Err) {
    IsWrongFormatError =
        Err.getCode() == index_error_code::invocation_list_wrong_format;
  });
  EXPECT_TRUE(IsWrongFormatError);
}
214 
TEST(CrossTranslationUnit, AmbiguousInvocationListIsDetected) {
  // The same source file is listed twice, once per target architecture (-m32
  // and -m64). Disambiguating such entries is the user's responsibility, so
  // the parser is expected to reject the list as ambiguous.
  auto Input = R"(
  /tmp/main.cpp:
    - clang++
    - -c
    - -m32
    - -o
    - main32.o
    - /tmp/main.cpp
  /tmp/main.cpp:
    - clang++
    - -c
    - -m64
    - -o
    - main64.o
    - /tmp/main.cpp
  )";

  llvm::Expected<InvocationListTy> Parsed = parseInvocationList(Input);
  EXPECT_FALSE(static_cast<bool>(Parsed));

  bool SawAmbiguity = false;
  auto OnIndexError = [&SawAmbiguity](IndexError &Err) {
    const bool Ambiguous =
        Err.getCode() == index_error_code::invocation_list_ambiguous;
    SawAmbiguity = Ambiguous;
  };
  llvm::handleAllErrors(Parsed.takeError(), OnIndexError);
  EXPECT_TRUE(SawAmbiguity);
}
244 
/// A single-entry invocation list parses into one entry whose command line
/// is preserved argument by argument.
TEST(CrossTranslationUnit, SingleInvocationCanBeParsed) {
  auto Input = R"(
  /tmp/main.cpp:
    - clang++
    - /tmp/main.cpp
  )";
  llvm::Expected<InvocationListTy> Result = parseInvocationList(Input);
  // Stop on failure instead of dereferencing an errored Expected below.
  // Taking the error also keeps it from being left unchecked.
  ASSERT_TRUE(static_cast<bool>(Result)) << llvm::toString(Result.takeError());

  EXPECT_EQ(Result->size(), 1u);

  auto It = Result->find("/tmp/main.cpp");
  // Guard the dereference and the indexed accesses below.
  ASSERT_TRUE(It != Result->end());
  ASSERT_EQ(It->getValue().size(), 2u);
  EXPECT_EQ(It->getValue()[0], "clang++");
  EXPECT_EQ(It->getValue()[1], "/tmp/main.cpp");
}
261 
/// An invocation list with two entries parses into two entries, each keeping
/// its own command line argument by argument.
TEST(CrossTranslationUnit, MultipleInvocationsCanBeParsed) {
  auto Input = R"(
  /tmp/main.cpp:
    - clang++
    - /tmp/other.o
    - /tmp/main.cpp
  /tmp/other.cpp:
    - g++
    - -c
    - -o
    - /tmp/other.o
    - /tmp/other.cpp
  )";
  llvm::Expected<InvocationListTy> Result = parseInvocationList(Input);
  // Stop on failure instead of dereferencing an errored Expected below.
  // Taking the error also keeps it from being left unchecked.
  ASSERT_TRUE(static_cast<bool>(Result)) << llvm::toString(Result.takeError());

  EXPECT_EQ(Result->size(), 2u);

  auto It = Result->find("/tmp/main.cpp");
  ASSERT_TRUE(It != Result->end());
  EXPECT_EQ(It->getKey(), "/tmp/main.cpp");
  // Guard the indexed accesses below.
  ASSERT_EQ(It->getValue().size(), 3u);
  EXPECT_EQ(It->getValue()[0], "clang++");
  EXPECT_EQ(It->getValue()[1], "/tmp/other.o");
  EXPECT_EQ(It->getValue()[2], "/tmp/main.cpp");

  It = Result->find("/tmp/other.cpp");
  ASSERT_TRUE(It != Result->end());
  ASSERT_EQ(It->getValue().size(), 5u);
  EXPECT_EQ(It->getValue()[0], "g++");
  EXPECT_EQ(It->getValue()[1], "-c");
  EXPECT_EQ(It->getValue()[2], "-o");
  EXPECT_EQ(It->getValue()[3], "/tmp/other.o");
  EXPECT_EQ(It->getValue()[4], "/tmp/other.cpp");
}
295 
296 } // end namespace cross_tu
297 } // end namespace clang
298