1 //===--- SourceExtraction.cpp - Clang refactoring library -----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "clang/Tooling/Refactoring/Extract/SourceExtraction.h"
10 #include "clang/AST/Stmt.h"
11 #include "clang/AST/StmtCXX.h"
12 #include "clang/AST/StmtObjC.h"
13 #include "clang/Basic/SourceManager.h"
14 #include "clang/Lex/Lexer.h"
15 #include <optional>
16 
17 using namespace clang;
18 
19 namespace {
20 
21 /// Returns true if the token at the given location is a semicolon.
22 bool isSemicolonAtLocation(SourceLocation TokenLoc, const SourceManager &SM,
23                            const LangOptions &LangOpts) {
24   return Lexer::getSourceText(
25              CharSourceRange::getTokenRange(TokenLoc, TokenLoc), SM,
26              LangOpts) == ";";
27 }
28 
29 /// Returns true if there should be a semicolon after the given statement.
30 bool isSemicolonRequiredAfter(const Stmt *S) {
31   if (isa<CompoundStmt>(S))
32     return false;
33   if (const auto *If = dyn_cast<IfStmt>(S))
34     return isSemicolonRequiredAfter(If->getElse() ? If->getElse()
35                                                   : If->getThen());
36   if (const auto *While = dyn_cast<WhileStmt>(S))
37     return isSemicolonRequiredAfter(While->getBody());
38   if (const auto *For = dyn_cast<ForStmt>(S))
39     return isSemicolonRequiredAfter(For->getBody());
40   if (const auto *CXXFor = dyn_cast<CXXForRangeStmt>(S))
41     return isSemicolonRequiredAfter(CXXFor->getBody());
42   if (const auto *ObjCFor = dyn_cast<ObjCForCollectionStmt>(S))
43     return isSemicolonRequiredAfter(ObjCFor->getBody());
44   if(const auto *Switch = dyn_cast<SwitchStmt>(S))
45     return isSemicolonRequiredAfter(Switch->getBody());
46   if(const auto *Case = dyn_cast<SwitchCase>(S))
47     return isSemicolonRequiredAfter(Case->getSubStmt());
48   switch (S->getStmtClass()) {
49   case Stmt::DeclStmtClass:
50   case Stmt::CXXTryStmtClass:
51   case Stmt::ObjCAtSynchronizedStmtClass:
52   case Stmt::ObjCAutoreleasePoolStmtClass:
53   case Stmt::ObjCAtTryStmtClass:
54     return false;
55   default:
56     return true;
57   }
58 }
59 
60 /// Returns true if the two source locations are on the same line.
61 bool areOnSameLine(SourceLocation Loc1, SourceLocation Loc2,
62                    const SourceManager &SM) {
63   return !Loc1.isMacroID() && !Loc2.isMacroID() &&
64          SM.getSpellingLineNumber(Loc1) == SM.getSpellingLineNumber(Loc2);
65 }
66 
67 } // end anonymous namespace
68 
69 namespace clang {
70 namespace tooling {
71 
72 ExtractionSemicolonPolicy
73 ExtractionSemicolonPolicy::compute(const Stmt *S, SourceRange &ExtractedRange,
74                                    const SourceManager &SM,
75                                    const LangOptions &LangOpts) {
76   auto neededInExtractedFunction = []() {
77     return ExtractionSemicolonPolicy(true, false);
78   };
79   auto neededInOriginalFunction = []() {
80     return ExtractionSemicolonPolicy(false, true);
81   };
82 
83   /// The extracted expression should be terminated with a ';'. The call to
84   /// the extracted function will replace this expression, so it won't need
85   /// a terminating ';'.
86   if (isa<Expr>(S))
87     return neededInExtractedFunction();
88 
89   /// Some statements don't need to be terminated with ';'. The call to the
90   /// extracted function will be a standalone statement, so it should be
91   /// terminated with a ';'.
92   bool NeedsSemi = isSemicolonRequiredAfter(S);
93   if (!NeedsSemi)
94     return neededInOriginalFunction();
95 
96   /// Some statements might end at ';'. The extraction will move that ';', so
97   /// the call to the extracted function should be terminated with a ';'.
98   SourceLocation End = ExtractedRange.getEnd();
99   if (isSemicolonAtLocation(End, SM, LangOpts))
100     return neededInOriginalFunction();
101 
102   /// Other statements should generally have a trailing ';'. We can try to find
103   /// it and move it together it with the extracted code.
104   std::optional<Token> NextToken = Lexer::findNextToken(End, SM, LangOpts);
105   if (NextToken && NextToken->is(tok::semi) &&
106       areOnSameLine(NextToken->getLocation(), End, SM)) {
107     ExtractedRange.setEnd(NextToken->getLocation());
108     return neededInOriginalFunction();
109   }
110 
111   /// Otherwise insert semicolons in both places.
112   return ExtractionSemicolonPolicy(true, true);
113 }
114 
115 } // end namespace tooling
116 } // end namespace clang
117