1 //===-- ClangIncludeFixer.cpp - Standalone include fixer ------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "FuzzySymbolIndex.h"
10 #include "InMemorySymbolIndex.h"
11 #include "IncludeFixer.h"
12 #include "IncludeFixerContext.h"
13 #include "SymbolIndexManager.h"
14 #include "YamlSymbolIndex.h"
15 #include "clang/Format/Format.h"
16 #include "clang/Frontend/TextDiagnosticPrinter.h"
17 #include "clang/Rewrite/Core/Rewriter.h"
18 #include "clang/Tooling/CommonOptionsParser.h"
19 #include "clang/Tooling/Core/Replacement.h"
20 #include "clang/Tooling/Tooling.h"
21 #include "llvm/Support/CommandLine.h"
22 #include "llvm/Support/Path.h"
23 #include "llvm/Support/YAMLTraits.h"
24 
25 using namespace clang;
26 using namespace llvm;
27 using clang::include_fixer::IncludeFixerContext;
28 
29 LLVM_YAML_IS_DOCUMENT_LIST_VECTOR(IncludeFixerContext)
30 LLVM_YAML_IS_FLOW_SEQUENCE_VECTOR(IncludeFixerContext::HeaderInfo)
31 LLVM_YAML_IS_FLOW_SEQUENCE_VECTOR(IncludeFixerContext::QuerySymbolInfo)
32 
33 namespace llvm {
34 namespace yaml {
35 
36 template <> struct MappingTraits<tooling::Range> {
37   struct NormalizedRange {
NormalizedRangellvm::yaml::MappingTraits::NormalizedRange38     NormalizedRange(const IO &) : Offset(0), Length(0) {}
39 
NormalizedRangellvm::yaml::MappingTraits::NormalizedRange40     NormalizedRange(const IO &, const tooling::Range &R)
41         : Offset(R.getOffset()), Length(R.getLength()) {}
42 
denormalizellvm::yaml::MappingTraits::NormalizedRange43     tooling::Range denormalize(const IO &) {
44       return tooling::Range(Offset, Length);
45     }
46 
47     unsigned Offset;
48     unsigned Length;
49   };
mappingllvm::yaml::MappingTraits50   static void mapping(IO &IO, tooling::Range &Info) {
51     MappingNormalization<NormalizedRange, tooling::Range> Keys(IO, Info);
52     IO.mapRequired("Offset", Keys->Offset);
53     IO.mapRequired("Length", Keys->Length);
54   }
55 };
56 
57 template <> struct MappingTraits<IncludeFixerContext::HeaderInfo> {
mappingllvm::yaml::MappingTraits58   static void mapping(IO &io, IncludeFixerContext::HeaderInfo &Info) {
59     io.mapRequired("Header", Info.Header);
60     io.mapRequired("QualifiedName", Info.QualifiedName);
61   }
62 };
63 
64 template <> struct MappingTraits<IncludeFixerContext::QuerySymbolInfo> {
mappingllvm::yaml::MappingTraits65   static void mapping(IO &io, IncludeFixerContext::QuerySymbolInfo &Info) {
66     io.mapRequired("RawIdentifier", Info.RawIdentifier);
67     io.mapRequired("Range", Info.Range);
68   }
69 };
70 
71 template <> struct MappingTraits<IncludeFixerContext> {
mappingllvm::yaml::MappingTraits72   static void mapping(IO &IO, IncludeFixerContext &Context) {
73     IO.mapRequired("QuerySymbolInfos", Context.QuerySymbolInfos);
74     IO.mapRequired("HeaderInfos", Context.HeaderInfos);
75     IO.mapRequired("FilePath", Context.FilePath);
76   }
77 };
78 } // namespace yaml
79 } // namespace llvm
80 
81 namespace {
82 cl::OptionCategory IncludeFixerCategory("Tool options");
83 
84 enum DatabaseFormatTy {
85   fixed,     ///< Hard-coded mapping.
86   yaml,      ///< Yaml database created by find-all-symbols.
87   fuzzyYaml, ///< Yaml database with fuzzy-matched identifiers.
88 };
89 
90 cl::opt<DatabaseFormatTy> DatabaseFormat(
91     "db", cl::desc("Specify input format"),
92     cl::values(clEnumVal(fixed, "Hard-coded mapping"),
93                clEnumVal(yaml, "Yaml database created by find-all-symbols"),
94                clEnumVal(fuzzyYaml, "Yaml database, with fuzzy-matched names")),
95     cl::init(yaml), cl::cat(IncludeFixerCategory));
96 
97 cl::opt<std::string> Input("input",
98                            cl::desc("String to initialize the database"),
99                            cl::cat(IncludeFixerCategory));
100 
101 cl::opt<std::string>
102     QuerySymbol("query-symbol",
103                  cl::desc("Query a given symbol (e.g. \"a::b::foo\") in\n"
104                           "database directly without parsing the file."),
105                  cl::cat(IncludeFixerCategory));
106 
107 cl::opt<bool>
108     MinimizeIncludePaths("minimize-paths",
109                          cl::desc("Whether to minimize added include paths"),
110                          cl::init(true), cl::cat(IncludeFixerCategory));
111 
112 cl::opt<bool> Quiet("q", cl::desc("Reduce terminal output"), cl::init(false),
113                     cl::cat(IncludeFixerCategory));
114 
115 cl::opt<bool>
116     STDINMode("stdin",
117               cl::desc("Override source file's content (in the overlaying\n"
118                        "virtual file system) with input from <stdin> and run\n"
119                        "the tool on the new content with the compilation\n"
120                        "options of the source file. This mode is currently\n"
121                        "used for editor integration."),
122               cl::init(false), cl::cat(IncludeFixerCategory));
123 
124 cl::opt<bool> OutputHeaders(
125     "output-headers",
126     cl::desc("Print the symbol being queried and all its relevant headers in\n"
127              "JSON format to stdout:\n"
128              "  {\n"
129              "    \"FilePath\": \"/path/to/foo.cc\",\n"
130              "    \"QuerySymbolInfos\": [\n"
131              "       {\"RawIdentifier\": \"foo\",\n"
132              "        \"Range\": {\"Offset\": 0, \"Length\": 3}}\n"
133              "    ],\n"
134              "    \"HeaderInfos\": [ {\"Header\": \"\\\"foo_a.h\\\"\",\n"
135              "                      \"QualifiedName\": \"a::foo\"} ]\n"
136              "  }"),
137     cl::init(false), cl::cat(IncludeFixerCategory));
138 
139 cl::opt<std::string> InsertHeader(
140     "insert-header",
141     cl::desc("Insert a specific header. This should run with STDIN mode.\n"
142              "The result is written to stdout. It is currently used for\n"
143              "editor integration. Support YAML/JSON format:\n"
144              "  -insert-header=\"{\n"
145              "     FilePath: \"/path/to/foo.cc\",\n"
146              "     QuerySymbolInfos: [\n"
147              "       {RawIdentifier: foo,\n"
148              "        Range: {Offset: 0, Length: 3}}\n"
149              "     ],\n"
150              "     HeaderInfos: [ {Headers: \"\\\"foo_a.h\\\"\",\n"
151              "                     QualifiedName: \"a::foo\"} ]}\""),
152     cl::init(""), cl::cat(IncludeFixerCategory));
153 
154 cl::opt<std::string>
155     Style("style",
156           cl::desc("Fallback style for reformatting after inserting new\n"
157                    "headers if there is no clang-format config file found."),
158           cl::init("llvm"), cl::cat(IncludeFixerCategory));
159 
160 std::unique_ptr<include_fixer::SymbolIndexManager>
createSymbolIndexManager(StringRef FilePath)161 createSymbolIndexManager(StringRef FilePath) {
162   using find_all_symbols::SymbolInfo;
163 
164   auto SymbolIndexMgr = std::make_unique<include_fixer::SymbolIndexManager>();
165   switch (DatabaseFormat) {
166   case fixed: {
167     // Parse input and fill the database with it.
168     // <symbol>=<header><, header...>
169     // Multiple symbols can be given, separated by semicolons.
170     std::map<std::string, std::vector<std::string>> SymbolsMap;
171     SmallVector<StringRef, 4> SemicolonSplits;
172     StringRef(Input).split(SemicolonSplits, ";");
173     std::vector<find_all_symbols::SymbolAndSignals> Symbols;
174     for (StringRef Pair : SemicolonSplits) {
175       auto Split = Pair.split('=');
176       std::vector<std::string> Headers;
177       SmallVector<StringRef, 4> CommaSplits;
178       Split.second.split(CommaSplits, ",");
179       for (size_t I = 0, E = CommaSplits.size(); I != E; ++I)
180         Symbols.push_back(
181             {SymbolInfo(Split.first.trim(), SymbolInfo::SymbolKind::Unknown,
182                         CommaSplits[I].trim(), {}),
183              // Use fake "seen" signal for tests, so first header wins.
184              SymbolInfo::Signals(/*Seen=*/static_cast<unsigned>(E - I),
185                                  /*Used=*/0)});
186     }
187     SymbolIndexMgr->addSymbolIndex([=]() {
188       return std::make_unique<include_fixer::InMemorySymbolIndex>(Symbols);
189     });
190     break;
191   }
192   case yaml: {
193     auto CreateYamlIdx = [=]() -> std::unique_ptr<include_fixer::SymbolIndex> {
194       llvm::ErrorOr<std::unique_ptr<include_fixer::YamlSymbolIndex>> DB(
195           nullptr);
196       if (!Input.empty()) {
197         DB = include_fixer::YamlSymbolIndex::createFromFile(Input);
198       } else {
199         // If we don't have any input file, look in the directory of the
200         // first
201         // file and its parents.
202         SmallString<128> AbsolutePath(tooling::getAbsolutePath(FilePath));
203         StringRef Directory = llvm::sys::path::parent_path(AbsolutePath);
204         DB = include_fixer::YamlSymbolIndex::createFromDirectory(
205             Directory, "find_all_symbols_db.yaml");
206       }
207 
208       if (!DB) {
209         llvm::errs() << "Couldn't find YAML db: " << DB.getError().message()
210                      << '\n';
211         return nullptr;
212       }
213       return std::move(*DB);
214     };
215 
216     SymbolIndexMgr->addSymbolIndex(std::move(CreateYamlIdx));
217     break;
218   }
219   case fuzzyYaml: {
220     // This mode is not very useful, because we don't correct the identifier.
221     // It's main purpose is to expose FuzzySymbolIndex to tests.
222     SymbolIndexMgr->addSymbolIndex(
223         []() -> std::unique_ptr<include_fixer::SymbolIndex> {
224           auto DB = include_fixer::FuzzySymbolIndex::createFromYAML(Input);
225           if (!DB) {
226             llvm::errs() << "Couldn't load fuzzy YAML db: "
227                          << llvm::toString(DB.takeError()) << '\n';
228             return nullptr;
229           }
230           return std::move(*DB);
231         });
232     break;
233   }
234   }
235   return SymbolIndexMgr;
236 }
237 
writeToJson(llvm::raw_ostream & OS,const IncludeFixerContext & Context)238 void writeToJson(llvm::raw_ostream &OS, const IncludeFixerContext& Context) {
239   OS << "{\n"
240      << "  \"FilePath\": \""
241      << llvm::yaml::escape(Context.getFilePath()) << "\",\n"
242      << "  \"QuerySymbolInfos\": [\n";
243   for (const auto &Info : Context.getQuerySymbolInfos()) {
244     OS << "     {\"RawIdentifier\": \"" << Info.RawIdentifier << "\",\n";
245     OS << "      \"Range\":{";
246     OS << "\"Offset\":" << Info.Range.getOffset() << ",";
247     OS << "\"Length\":" << Info.Range.getLength() << "}}";
248     if (&Info != &Context.getQuerySymbolInfos().back())
249       OS << ",\n";
250   }
251   OS << "\n  ],\n";
252   OS << "  \"HeaderInfos\": [\n";
253   const auto &HeaderInfos = Context.getHeaderInfos();
254   for (const auto &Info : HeaderInfos) {
255     OS << "     {\"Header\": \"" << llvm::yaml::escape(Info.Header) << "\",\n"
256        << "      \"QualifiedName\": \"" << Info.QualifiedName << "\"}";
257     if (&Info != &HeaderInfos.back())
258       OS << ",\n";
259   }
260   OS << "\n";
261   OS << "  ]\n";
262   OS << "}\n";
263 }
264 
includeFixerMain(int argc,const char ** argv)265 int includeFixerMain(int argc, const char **argv) {
266   auto ExpectedParser =
267       tooling::CommonOptionsParser::create(argc, argv, IncludeFixerCategory);
268   if (!ExpectedParser) {
269     llvm::errs() << ExpectedParser.takeError();
270     return 1;
271   }
272   tooling::CommonOptionsParser &options = ExpectedParser.get();
273   tooling::ClangTool tool(options.getCompilations(),
274                           options.getSourcePathList());
275 
276   llvm::StringRef SourceFilePath = options.getSourcePathList().front();
277   // In STDINMode, we override the file content with the <stdin> input.
278   // Since `tool.mapVirtualFile` takes `StringRef`, we define `Code` outside of
279   // the if-block so that `Code` is not released after the if-block.
280   std::unique_ptr<llvm::MemoryBuffer> Code;
281   if (STDINMode) {
282     assert(options.getSourcePathList().size() == 1 &&
283            "Expect exactly one file path in STDINMode.");
284     llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> CodeOrErr =
285         MemoryBuffer::getSTDIN();
286     if (std::error_code EC = CodeOrErr.getError()) {
287       errs() << EC.message() << "\n";
288       return 1;
289     }
290     Code = std::move(CodeOrErr.get());
291     if (Code->getBufferSize() == 0)
292       return 0;  // Skip empty files.
293 
294     tool.mapVirtualFile(SourceFilePath, Code->getBuffer());
295   }
296 
297   if (!InsertHeader.empty()) {
298     if (!STDINMode) {
299       errs() << "Should be running in STDIN mode\n";
300       return 1;
301     }
302 
303     llvm::yaml::Input yin(InsertHeader);
304     IncludeFixerContext Context;
305     yin >> Context;
306 
307     const auto &HeaderInfos = Context.getHeaderInfos();
308     assert(!HeaderInfos.empty());
309     // We only accept one unique header.
310     // Check all elements in HeaderInfos have the same header.
311     bool IsUniqueHeader = std::equal(
312         HeaderInfos.begin()+1, HeaderInfos.end(), HeaderInfos.begin(),
313         [](const IncludeFixerContext::HeaderInfo &LHS,
314            const IncludeFixerContext::HeaderInfo &RHS) {
315           return LHS.Header == RHS.Header;
316         });
317     if (!IsUniqueHeader) {
318       errs() << "Expect exactly one unique header.\n";
319       return 1;
320     }
321 
322     // If a header has multiple symbols, we won't add the missing namespace
323     // qualifiers because we don't know which one is exactly used.
324     //
325     // Check whether all elements in HeaderInfos have the same qualified name.
326     bool IsUniqueQualifiedName = std::equal(
327         HeaderInfos.begin() + 1, HeaderInfos.end(), HeaderInfos.begin(),
328         [](const IncludeFixerContext::HeaderInfo &LHS,
329            const IncludeFixerContext::HeaderInfo &RHS) {
330           return LHS.QualifiedName == RHS.QualifiedName;
331         });
332     auto InsertStyle = format::getStyle(format::DefaultFormatStyle,
333                                         Context.getFilePath(), Style);
334     if (!InsertStyle) {
335       llvm::errs() << llvm::toString(InsertStyle.takeError()) << "\n";
336       return 1;
337     }
338     auto Replacements = clang::include_fixer::createIncludeFixerReplacements(
339         Code->getBuffer(), Context, *InsertStyle,
340         /*AddQualifiers=*/IsUniqueQualifiedName);
341     if (!Replacements) {
342       errs() << "Failed to create replacements: "
343              << llvm::toString(Replacements.takeError()) << "\n";
344       return 1;
345     }
346 
347     auto ChangedCode =
348         tooling::applyAllReplacements(Code->getBuffer(), *Replacements);
349     if (!ChangedCode) {
350       llvm::errs() << llvm::toString(ChangedCode.takeError()) << "\n";
351       return 1;
352     }
353     llvm::outs() << *ChangedCode;
354     return 0;
355   }
356 
357   // Set up data source.
358   std::unique_ptr<include_fixer::SymbolIndexManager> SymbolIndexMgr =
359       createSymbolIndexManager(SourceFilePath);
360   if (!SymbolIndexMgr)
361     return 1;
362 
363   // Query symbol mode.
364   if (!QuerySymbol.empty()) {
365     auto MatchedSymbols = SymbolIndexMgr->search(
366         QuerySymbol, /*IsNestedSearch=*/true, SourceFilePath);
367     for (auto &Symbol : MatchedSymbols) {
368       std::string HeaderPath = Symbol.getFilePath().str();
369       Symbol.SetFilePath(((HeaderPath[0] == '"' || HeaderPath[0] == '<')
370                               ? HeaderPath
371                               : "\"" + HeaderPath + "\""));
372     }
373 
374     // We leave an empty symbol range as we don't know the range of the symbol
375     // being queried in this mode. clang-include-fixer won't add namespace
376     // qualifiers if the symbol range is empty, which also fits this case.
377     IncludeFixerContext::QuerySymbolInfo Symbol;
378     Symbol.RawIdentifier = QuerySymbol;
379     auto Context =
380         IncludeFixerContext(SourceFilePath, {Symbol}, MatchedSymbols);
381     writeToJson(llvm::outs(), Context);
382     return 0;
383   }
384 
385   // Now run our tool.
386   std::vector<include_fixer::IncludeFixerContext> Contexts;
387   include_fixer::IncludeFixerActionFactory Factory(*SymbolIndexMgr, Contexts,
388                                                    Style, MinimizeIncludePaths);
389 
390   if (tool.run(&Factory) != 0) {
391     // We suppress all Clang diagnostics (because they would be wrong,
392     // clang-include-fixer does custom recovery) but still want to give some
393     // feedback in case there was a compiler error we couldn't recover from.
394     // The most common case for this is a #include in the file that couldn't be
395     // found.
396     llvm::errs() << "Fatal compiler error occurred while parsing file!"
397                     " (incorrect include paths?)\n";
398     return 1;
399   }
400 
401   assert(!Contexts.empty());
402 
403   if (OutputHeaders) {
404     // FIXME: Print contexts of all processing files instead of the first one.
405     writeToJson(llvm::outs(), Contexts.front());
406     return 0;
407   }
408 
409   std::vector<tooling::Replacements> FixerReplacements;
410   for (const auto &Context : Contexts) {
411     StringRef FilePath = Context.getFilePath();
412     auto InsertStyle =
413         format::getStyle(format::DefaultFormatStyle, FilePath, Style);
414     if (!InsertStyle) {
415       llvm::errs() << llvm::toString(InsertStyle.takeError()) << "\n";
416       return 1;
417     }
418     auto Buffer = llvm::MemoryBuffer::getFile(FilePath);
419     if (!Buffer) {
420       errs() << "Couldn't open file: " + FilePath.str() + ": "
421              << Buffer.getError().message() + "\n";
422       return 1;
423     }
424 
425     auto Replacements = clang::include_fixer::createIncludeFixerReplacements(
426         Buffer.get()->getBuffer(), Context, *InsertStyle);
427     if (!Replacements) {
428       errs() << "Failed to create replacement: "
429              << llvm::toString(Replacements.takeError()) << "\n";
430       return 1;
431     }
432     FixerReplacements.push_back(*Replacements);
433   }
434 
435   if (!Quiet) {
436     for (const auto &Context : Contexts) {
437       if (!Context.getHeaderInfos().empty()) {
438         llvm::errs() << "Added #include "
439                      << Context.getHeaderInfos().front().Header << " for "
440                      << Context.getFilePath() << "\n";
441       }
442     }
443   }
444 
445   if (STDINMode) {
446     assert(FixerReplacements.size() == 1);
447     auto ChangedCode = tooling::applyAllReplacements(Code->getBuffer(),
448                                                      FixerReplacements.front());
449     if (!ChangedCode) {
450       llvm::errs() << llvm::toString(ChangedCode.takeError()) << "\n";
451       return 1;
452     }
453     llvm::outs() << *ChangedCode;
454     return 0;
455   }
456 
457   // Set up a new source manager for applying the resulting replacements.
458   IntrusiveRefCntPtr<DiagnosticOptions> DiagOpts(new DiagnosticOptions);
459   DiagnosticsEngine Diagnostics(new DiagnosticIDs, &*DiagOpts);
460   TextDiagnosticPrinter DiagnosticPrinter(outs(), &*DiagOpts);
461   SourceManager SM(Diagnostics, tool.getFiles());
462   Diagnostics.setClient(&DiagnosticPrinter, false);
463 
464   // Write replacements to disk.
465   Rewriter Rewrites(SM, LangOptions());
466   for (const auto &Replacement : FixerReplacements) {
467     if (!tooling::applyAllReplacements(Replacement, Rewrites)) {
468       llvm::errs() << "Failed to apply replacements.\n";
469       return 1;
470     }
471   }
472   return Rewrites.overwriteChangedFiles();
473 }
474 
475 } // namespace
476 
main(int argc,const char ** argv)477 int main(int argc, const char **argv) {
478   return includeFixerMain(argc, argv);
479 }
480