1 //===--- ExtractFunction.cpp -------------------------------------*- C++-*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Extracts statements to a new function and replaces the statements with a
10 // call to the new function.
11 // Before:
12 //   void f(int a) {
13 //     [[if(a < 5)
14 //       a = 5;]]
15 //   }
16 // After:
17 //   void extracted(int &a) {
18 //     if(a < 5)
19 //       a = 5;
20 //   }
21 //   void f(int a) {
22 //     extracted(a);
23 //   }
24 //
25 // - Only extract statements
26 // - Extracts from non-templated free functions only.
27 // - Parameters are const only if the declaration was const
28 //   - Always passed by l-value reference
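// - Void return type, unless the extracted code always ends in a return; in
//   that case the enclosing function's return type is reused.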
30 // - Cannot extract declarations that will be needed in the original function
31 //   after extraction.
32 // - Checks for broken control flow (break/continue without loop/switch)
33 //
34 // 1. ExtractFunction is the tweak subclass
35 //    - Prepare does basic analysis of the selection and is therefore fast.
36 //      Successful prepare doesn't always mean we can apply the tweak.
37 //    - Apply does a more detailed analysis and can be slower. In case of
38 //      failure, we let the user know that we are unable to perform extraction.
// 2. ExtractionZone stores information about the range being extracted and
//    the enclosing function.
41 // 3. NewFunction stores properties of the extracted function and provides
42 //    methods for rendering it.
43 // 4. CapturedZoneInfo uses a RecursiveASTVisitor to capture information about
44 //    the extraction like declarations, existing return statements, etc.
45 // 5. getExtractedFunction is responsible for analyzing the CapturedZoneInfo and
46 //    creating a NewFunction.
47 //===----------------------------------------------------------------------===//
48 
49 #include "AST.h"
50 #include "ParsedAST.h"
51 #include "Selection.h"
52 #include "SourceCode.h"
53 #include "refactor/Tweak.h"
54 #include "support/Logger.h"
55 #include "clang/AST/ASTContext.h"
56 #include "clang/AST/Decl.h"
57 #include "clang/AST/DeclTemplate.h"
58 #include "clang/AST/RecursiveASTVisitor.h"
59 #include "clang/AST/Stmt.h"
60 #include "clang/Basic/LangOptions.h"
61 #include "clang/Basic/SourceLocation.h"
62 #include "clang/Basic/SourceManager.h"
63 #include "clang/Lex/Lexer.h"
64 #include "clang/Tooling/Core/Replacement.h"
65 #include "clang/Tooling/Refactoring/Extract/SourceExtraction.h"
66 #include "llvm/ADT/None.h"
67 #include "llvm/ADT/Optional.h"
68 #include "llvm/ADT/SmallVector.h"
69 #include "llvm/ADT/StringRef.h"
70 #include "llvm/ADT/iterator_range.h"
71 #include "llvm/Support/Casting.h"
72 #include "llvm/Support/Error.h"
73 
74 namespace clang {
75 namespace clangd {
76 namespace {
77 
78 using Node = SelectionTree::Node;
79 
// The ExtractionZone is the code being extracted; the EnclosingFunction is the
// function/method inside which that zone lies.
// Relative to the zone, the file is split into the four regions below.
enum class ZoneRelative {
  // Inside EnclosingFunction, before the zone.
  Before,
  // Inside the zone itself.
  Inside,
  // Inside EnclosingFunction, after the zone.
  After,
  // Outside EnclosingFunction altogether.
  OutsideFunc
};
89 
90 // A RootStmt is a statement that's fully selected including all it's children
91 // and it's parent is unselected.
92 // Check if a node is a root statement.
isRootStmt(const Node * N)93 bool isRootStmt(const Node *N) {
94   if (!N->ASTNode.get<Stmt>())
95     return false;
96   // Root statement cannot be partially selected.
97   if (N->Selected == SelectionTree::Partial)
98     return false;
99   // Only DeclStmt can be an unselected RootStmt since VarDecls claim the entire
100   // selection range in selectionTree.
101   if (N->Selected == SelectionTree::Unselected && !N->ASTNode.get<DeclStmt>())
102     return false;
103   return true;
104 }
105 
106 // Returns the (unselected) parent of all RootStmts given the commonAncestor.
107 // Returns null if:
108 // 1. any node is partially selected
109 // 2. If all completely selected nodes don't have the same common parent
110 // 3. Any child of Parent isn't a RootStmt.
111 // Returns null if any child is not a RootStmt.
112 // We only support extraction of RootStmts since it allows us to extract without
113 // having to change the selection range. Also, this means that any scope that
114 // begins in selection range, ends in selection range and any scope that begins
115 // outside the selection range, ends outside as well.
getParentOfRootStmts(const Node * CommonAnc)116 const Node *getParentOfRootStmts(const Node *CommonAnc) {
117   if (!CommonAnc)
118     return nullptr;
119   const Node *Parent = nullptr;
120   switch (CommonAnc->Selected) {
121   case SelectionTree::Selection::Unselected:
122     // Typically a block, with the { and } unselected, could also be ForStmt etc
123     // Ensure all Children are RootStmts.
124     Parent = CommonAnc;
125     break;
126   case SelectionTree::Selection::Partial:
127     // Only a fully-selected single statement can be selected.
128     return nullptr;
129   case SelectionTree::Selection::Complete:
130     // If the Common Ancestor is completely selected, then it's a root statement
131     // and its parent will be unselected.
132     Parent = CommonAnc->Parent;
133     // If parent is a DeclStmt, even though it's unselected, we consider it a
134     // root statement and return its parent. This is done because the VarDecls
135     // claim the entire selection range of the Declaration and DeclStmt is
136     // always unselected.
137     if (Parent->ASTNode.get<DeclStmt>())
138       Parent = Parent->Parent;
139     break;
140   }
141   // Ensure all Children are RootStmts.
142   return llvm::all_of(Parent->Children, isRootStmt) ? Parent : nullptr;
143 }
144 
145 // The ExtractionZone class forms a view of the code wrt Zone.
146 struct ExtractionZone {
147   // Parent of RootStatements being extracted.
148   const Node *Parent = nullptr;
149   // The half-open file range of the code being extracted.
150   SourceRange ZoneRange;
151   // The function inside which our zone resides.
152   const FunctionDecl *EnclosingFunction = nullptr;
153   // The half-open file range of the enclosing function.
154   SourceRange EnclosingFuncRange;
getInsertionPointclang::clangd::__anon736e3b0d0111::ExtractionZone155   SourceLocation getInsertionPoint() const {
156     return EnclosingFuncRange.getBegin();
157   }
158   bool isRootStmt(const Stmt *S) const;
159   // The last root statement is important to decide where we need to insert a
160   // semicolon after the extraction.
getLastRootStmtclang::clangd::__anon736e3b0d0111::ExtractionZone161   const Node *getLastRootStmt() const { return Parent->Children.back(); }
162   void generateRootStmts();
163 
164 private:
165   llvm::DenseSet<const Stmt *> RootStmts;
166 };
167 
168 // Whether the code in the extraction zone is guaranteed to return, assuming
169 // no broken control flow (unbound break/continue).
170 // This is a very naive check (does it end with a return stmt).
171 // Doing some rudimentary control flow analysis would cover more cases.
alwaysReturns(const ExtractionZone & EZ)172 bool alwaysReturns(const ExtractionZone &EZ) {
173   const Stmt *Last = EZ.getLastRootStmt()->ASTNode.get<Stmt>();
174   // Unwrap enclosing (unconditional) compound statement.
175   while (const auto *CS = llvm::dyn_cast<CompoundStmt>(Last)) {
176     if (CS->body_empty())
177       return false;
178     else
179       Last = CS->body_back();
180   }
181   return llvm::isa<ReturnStmt>(Last);
182 }
183 
isRootStmt(const Stmt * S) const184 bool ExtractionZone::isRootStmt(const Stmt *S) const {
185   return RootStmts.find(S) != RootStmts.end();
186 }
187 
188 // Generate RootStmts set
generateRootStmts()189 void ExtractionZone::generateRootStmts() {
190   for (const Node *Child : Parent->Children)
191     RootStmts.insert(Child->ASTNode.get<Stmt>());
192 }
193 
194 // Finds the function in which the zone lies.
findEnclosingFunction(const Node * CommonAnc)195 const FunctionDecl *findEnclosingFunction(const Node *CommonAnc) {
196   // Walk up the SelectionTree until we find a function Decl
197   for (const Node *CurNode = CommonAnc; CurNode; CurNode = CurNode->Parent) {
198     // Don't extract from lambdas
199     if (CurNode->ASTNode.get<LambdaExpr>())
200       return nullptr;
201     if (const FunctionDecl *Func = CurNode->ASTNode.get<FunctionDecl>()) {
202       // FIXME: Support extraction from methods.
203       if (isa<CXXMethodDecl>(Func))
204         return nullptr;
205       // FIXME: Support extraction from templated functions.
206       if (Func->isTemplated())
207         return nullptr;
208       return Func;
209     }
210   }
211   return nullptr;
212 }
213 
214 // Zone Range is the union of SourceRanges of all child Nodes in Parent since
215 // all child Nodes are RootStmts
findZoneRange(const Node * Parent,const SourceManager & SM,const LangOptions & LangOpts)216 llvm::Optional<SourceRange> findZoneRange(const Node *Parent,
217                                           const SourceManager &SM,
218                                           const LangOptions &LangOpts) {
219   SourceRange SR;
220   if (auto BeginFileRange = toHalfOpenFileRange(
221           SM, LangOpts, Parent->Children.front()->ASTNode.getSourceRange()))
222     SR.setBegin(BeginFileRange->getBegin());
223   else
224     return llvm::None;
225   if (auto EndFileRange = toHalfOpenFileRange(
226           SM, LangOpts, Parent->Children.back()->ASTNode.getSourceRange()))
227     SR.setEnd(EndFileRange->getEnd());
228   else
229     return llvm::None;
230   return SR;
231 }
232 
233 // Compute the range spanned by the enclosing function.
234 // FIXME: check if EnclosingFunction has any attributes as the AST doesn't
235 // always store the source range of the attributes and thus we end up extracting
236 // between the attributes and the EnclosingFunction.
237 llvm::Optional<SourceRange>
computeEnclosingFuncRange(const FunctionDecl * EnclosingFunction,const SourceManager & SM,const LangOptions & LangOpts)238 computeEnclosingFuncRange(const FunctionDecl *EnclosingFunction,
239                           const SourceManager &SM,
240                           const LangOptions &LangOpts) {
241   return toHalfOpenFileRange(SM, LangOpts, EnclosingFunction->getSourceRange());
242 }
243 
244 // returns true if Child can be a single RootStmt being extracted from
245 // EnclosingFunc.
validSingleChild(const Node * Child,const FunctionDecl * EnclosingFunc)246 bool validSingleChild(const Node *Child, const FunctionDecl *EnclosingFunc) {
247   // Don't extract expressions.
248   // FIXME: We should extract expressions that are "statements" i.e. not
249   // subexpressions
250   if (Child->ASTNode.get<Expr>())
251     return false;
252   // Extracting the body of EnclosingFunc would remove it's definition.
253   assert(EnclosingFunc->hasBody() &&
254          "We should always be extracting from a function body.");
255   if (Child->ASTNode.get<Stmt>() == EnclosingFunc->getBody())
256     return false;
257   return true;
258 }
259 
260 // FIXME: Check we're not extracting from the initializer/condition of a control
261 // flow structure.
findExtractionZone(const Node * CommonAnc,const SourceManager & SM,const LangOptions & LangOpts)262 llvm::Optional<ExtractionZone> findExtractionZone(const Node *CommonAnc,
263                                                   const SourceManager &SM,
264                                                   const LangOptions &LangOpts) {
265   ExtractionZone ExtZone;
266   ExtZone.Parent = getParentOfRootStmts(CommonAnc);
267   if (!ExtZone.Parent || ExtZone.Parent->Children.empty())
268     return llvm::None;
269   ExtZone.EnclosingFunction = findEnclosingFunction(ExtZone.Parent);
270   if (!ExtZone.EnclosingFunction)
271     return llvm::None;
272   // When there is a single RootStmt, we must check if it's valid for
273   // extraction.
274   if (ExtZone.Parent->Children.size() == 1 &&
275       !validSingleChild(ExtZone.getLastRootStmt(), ExtZone.EnclosingFunction))
276     return llvm::None;
277   if (auto FuncRange =
278           computeEnclosingFuncRange(ExtZone.EnclosingFunction, SM, LangOpts))
279     ExtZone.EnclosingFuncRange = *FuncRange;
280   if (auto ZoneRange = findZoneRange(ExtZone.Parent, SM, LangOpts))
281     ExtZone.ZoneRange = *ZoneRange;
282   if (ExtZone.EnclosingFuncRange.isInvalid() || ExtZone.ZoneRange.isInvalid())
283     return llvm::None;
284   ExtZone.generateRootStmts();
285   return ExtZone;
286 }
287 
288 // Stores information about the extracted function and provides methods for
289 // rendering it.
290 struct NewFunction {
291   struct Parameter {
292     std::string Name;
293     QualType TypeInfo;
294     bool PassByReference;
295     unsigned OrderPriority; // Lower value parameters are preferred first.
296     std::string render(const DeclContext *Context) const;
operator <clang::clangd::__anon736e3b0d0111::NewFunction::Parameter297     bool operator<(const Parameter &Other) const {
298       return OrderPriority < Other.OrderPriority;
299     }
300   };
301   std::string Name = "extracted";
302   QualType ReturnType;
303   std::vector<Parameter> Parameters;
304   SourceRange BodyRange;
305   SourceLocation InsertionPoint;
306   const DeclContext *EnclosingFuncContext;
307   bool CallerReturnsValue = false;
308   // Decides whether the extracted function body and the function call need a
309   // semicolon after extraction.
310   tooling::ExtractionSemicolonPolicy SemicolonPolicy;
NewFunctionclang::clangd::__anon736e3b0d0111::NewFunction311   NewFunction(tooling::ExtractionSemicolonPolicy SemicolonPolicy)
312       : SemicolonPolicy(SemicolonPolicy) {}
313   // Render the call for this function.
314   std::string renderCall() const;
315   // Render the definition for this function.
316   std::string renderDefinition(const SourceManager &SM) const;
317 
318 private:
319   std::string renderParametersForDefinition() const;
320   std::string renderParametersForCall() const;
321   // Generate the function body.
322   std::string getFuncBody(const SourceManager &SM) const;
323 };
324 
renderParametersForDefinition() const325 std::string NewFunction::renderParametersForDefinition() const {
326   std::string Result;
327   bool NeedCommaBefore = false;
328   for (const Parameter &P : Parameters) {
329     if (NeedCommaBefore)
330       Result += ", ";
331     NeedCommaBefore = true;
332     Result += P.render(EnclosingFuncContext);
333   }
334   return Result;
335 }
336 
renderParametersForCall() const337 std::string NewFunction::renderParametersForCall() const {
338   std::string Result;
339   bool NeedCommaBefore = false;
340   for (const Parameter &P : Parameters) {
341     if (NeedCommaBefore)
342       Result += ", ";
343     NeedCommaBefore = true;
344     Result += P.Name;
345   }
346   return Result;
347 }
348 
renderCall() const349 std::string NewFunction::renderCall() const {
350   return std::string(
351       llvm::formatv("{0}{1}({2}){3}", CallerReturnsValue ? "return " : "", Name,
352                     renderParametersForCall(),
353                     (SemicolonPolicy.isNeededInOriginalFunction() ? ";" : "")));
354 }
355 
renderDefinition(const SourceManager & SM) const356 std::string NewFunction::renderDefinition(const SourceManager &SM) const {
357   return std::string(llvm::formatv(
358       "{0} {1}({2}) {\n{3}\n}\n", printType(ReturnType, *EnclosingFuncContext),
359       Name, renderParametersForDefinition(), getFuncBody(SM)));
360 }
361 
getFuncBody(const SourceManager & SM) const362 std::string NewFunction::getFuncBody(const SourceManager &SM) const {
363   // FIXME: Generate tooling::Replacements instead of std::string to
364   // - hoist decls
365   // - add return statement
366   // - Add semicolon
367   return toSourceCode(SM, BodyRange).str() +
368          (SemicolonPolicy.isNeededInExtractedFunction() ? ";" : "");
369 }
370 
render(const DeclContext * Context) const371 std::string NewFunction::Parameter::render(const DeclContext *Context) const {
372   return printType(TypeInfo, *Context) + (PassByReference ? " &" : " ") + Name;
373 }
374 
375 // Stores captured information about Extraction Zone.
376 struct CapturedZoneInfo {
377   struct DeclInformation {
378     const Decl *TheDecl;
379     ZoneRelative DeclaredIn;
380     // index of the declaration or first reference.
381     unsigned DeclIndex;
382     bool IsReferencedInZone = false;
383     bool IsReferencedInPostZone = false;
384     // FIXME: Capture mutation information
DeclInformationclang::clangd::__anon736e3b0d0111::CapturedZoneInfo::DeclInformation385     DeclInformation(const Decl *TheDecl, ZoneRelative DeclaredIn,
386                     unsigned DeclIndex)
387         : TheDecl(TheDecl), DeclaredIn(DeclaredIn), DeclIndex(DeclIndex){};
388     // Marks the occurence of a reference for this declaration
389     void markOccurence(ZoneRelative ReferenceLoc);
390   };
391   // Maps Decls to their DeclInfo
392   llvm::DenseMap<const Decl *, DeclInformation> DeclInfoMap;
393   bool HasReturnStmt = false; // Are there any return statements in the zone?
394   bool AlwaysReturns = false; // Does the zone always return?
395   // Control flow is broken if we are extracting a break/continue without a
396   // corresponding parent loop/switch
397   bool BrokenControlFlow = false;
398   // FIXME: capture TypeAliasDecl and UsingDirectiveDecl
399   // FIXME: Capture type information as well.
400   DeclInformation *createDeclInfo(const Decl *D, ZoneRelative RelativeLoc);
401   DeclInformation *getDeclInfoFor(const Decl *D);
402 };
403 
404 CapturedZoneInfo::DeclInformation *
createDeclInfo(const Decl * D,ZoneRelative RelativeLoc)405 CapturedZoneInfo::createDeclInfo(const Decl *D, ZoneRelative RelativeLoc) {
406   // The new Decl's index is the size of the map so far.
407   auto InsertionResult = DeclInfoMap.insert(
408       {D, DeclInformation(D, RelativeLoc, DeclInfoMap.size())});
409   // Return the newly created DeclInfo
410   return &InsertionResult.first->second;
411 }
412 
413 CapturedZoneInfo::DeclInformation *
getDeclInfoFor(const Decl * D)414 CapturedZoneInfo::getDeclInfoFor(const Decl *D) {
415   // If the Decl doesn't exist, we
416   auto Iter = DeclInfoMap.find(D);
417   if (Iter == DeclInfoMap.end())
418     return nullptr;
419   return &Iter->second;
420 }
421 
markOccurence(ZoneRelative ReferenceLoc)422 void CapturedZoneInfo::DeclInformation::markOccurence(
423     ZoneRelative ReferenceLoc) {
424   switch (ReferenceLoc) {
425   case ZoneRelative::Inside:
426     IsReferencedInZone = true;
427     break;
428   case ZoneRelative::After:
429     IsReferencedInPostZone = true;
430     break;
431   default:
432     break;
433   }
434 }
435 
isLoop(const Stmt * S)436 bool isLoop(const Stmt *S) {
437   return isa<ForStmt>(S) || isa<DoStmt>(S) || isa<WhileStmt>(S) ||
438          isa<CXXForRangeStmt>(S);
439 }
440 
441 // Captures information from Extraction Zone
captureZoneInfo(const ExtractionZone & ExtZone)442 CapturedZoneInfo captureZoneInfo(const ExtractionZone &ExtZone) {
443   // We use the ASTVisitor instead of using the selection tree since we need to
444   // find references in the PostZone as well.
445   // FIXME: Check which statements we don't allow to extract.
446   class ExtractionZoneVisitor
447       : public clang::RecursiveASTVisitor<ExtractionZoneVisitor> {
448   public:
449     ExtractionZoneVisitor(const ExtractionZone &ExtZone) : ExtZone(ExtZone) {
450       TraverseDecl(const_cast<FunctionDecl *>(ExtZone.EnclosingFunction));
451     }
452 
453     bool TraverseStmt(Stmt *S) {
454       if (!S)
455         return true;
456       bool IsRootStmt = ExtZone.isRootStmt(const_cast<const Stmt *>(S));
457       // If we are starting traversal of a RootStmt, we are somewhere inside
458       // ExtractionZone
459       if (IsRootStmt)
460         CurrentLocation = ZoneRelative::Inside;
461       addToLoopSwitchCounters(S, 1);
462       // Traverse using base class's TraverseStmt
463       RecursiveASTVisitor::TraverseStmt(S);
464       addToLoopSwitchCounters(S, -1);
465       // We set the current location as after since next stmt will either be a
466       // RootStmt (handled at the beginning) or after extractionZone
467       if (IsRootStmt)
468         CurrentLocation = ZoneRelative::After;
469       return true;
470     }
471 
472     // Add Increment to CurNumberOf{Loops,Switch} if statement is
473     // {Loop,Switch} and inside Extraction Zone.
474     void addToLoopSwitchCounters(Stmt *S, int Increment) {
475       if (CurrentLocation != ZoneRelative::Inside)
476         return;
477       if (isLoop(S))
478         CurNumberOfNestedLoops += Increment;
479       else if (isa<SwitchStmt>(S))
480         CurNumberOfSwitch += Increment;
481     }
482 
483     // Decrement CurNumberOf{NestedLoops,Switch} if statement is {Loop,Switch}
484     // and inside Extraction Zone.
485     void decrementLoopSwitchCounters(Stmt *S) {
486       if (CurrentLocation != ZoneRelative::Inside)
487         return;
488       if (isLoop(S))
489         CurNumberOfNestedLoops--;
490       else if (isa<SwitchStmt>(S))
491         CurNumberOfSwitch--;
492     }
493 
494     bool VisitDecl(Decl *D) {
495       Info.createDeclInfo(D, CurrentLocation);
496       return true;
497     }
498 
499     bool VisitDeclRefExpr(DeclRefExpr *DRE) {
500       // Find the corresponding Decl and mark it's occurrence.
501       const Decl *D = DRE->getDecl();
502       auto *DeclInfo = Info.getDeclInfoFor(D);
503       // If no Decl was found, the Decl must be outside the enclosingFunc.
504       if (!DeclInfo)
505         DeclInfo = Info.createDeclInfo(D, ZoneRelative::OutsideFunc);
506       DeclInfo->markOccurence(CurrentLocation);
507       // FIXME: check if reference mutates the Decl being referred.
508       return true;
509     }
510 
511     bool VisitReturnStmt(ReturnStmt *Return) {
512       if (CurrentLocation == ZoneRelative::Inside)
513         Info.HasReturnStmt = true;
514       return true;
515     }
516 
517     bool VisitBreakStmt(BreakStmt *Break) {
518       // Control flow is broken if break statement is selected without any
519       // parent loop or switch statement.
520       if (CurrentLocation == ZoneRelative::Inside &&
521           !(CurNumberOfNestedLoops || CurNumberOfSwitch))
522         Info.BrokenControlFlow = true;
523       return true;
524     }
525 
526     bool VisitContinueStmt(ContinueStmt *Continue) {
527       // Control flow is broken if Continue statement is selected without any
528       // parent loop
529       if (CurrentLocation == ZoneRelative::Inside && !CurNumberOfNestedLoops)
530         Info.BrokenControlFlow = true;
531       return true;
532     }
533     CapturedZoneInfo Info;
534     const ExtractionZone &ExtZone;
535     ZoneRelative CurrentLocation = ZoneRelative::Before;
536     // Number of {loop,switch} statements that are currently in the traversal
537     // stack inside Extraction Zone. Used to check for broken control flow.
538     unsigned CurNumberOfNestedLoops = 0;
539     unsigned CurNumberOfSwitch = 0;
540   };
541   ExtractionZoneVisitor Visitor(ExtZone);
542   CapturedZoneInfo Result = std::move(Visitor.Info);
543   Result.AlwaysReturns = alwaysReturns(ExtZone);
544   return Result;
545 }
546 
547 // Adds parameters to ExtractedFunc.
548 // Returns true if able to find the parameters successfully and no hoisting
549 // needed.
550 // FIXME: Check if the declaration has a local/anonymous type
createParameters(NewFunction & ExtractedFunc,const CapturedZoneInfo & CapturedInfo)551 bool createParameters(NewFunction &ExtractedFunc,
552                       const CapturedZoneInfo &CapturedInfo) {
553   for (const auto &KeyVal : CapturedInfo.DeclInfoMap) {
554     const auto &DeclInfo = KeyVal.second;
555     // If a Decl was Declared in zone and referenced in post zone, it
556     // needs to be hoisted (we bail out in that case).
557     // FIXME: Support Decl Hoisting.
558     if (DeclInfo.DeclaredIn == ZoneRelative::Inside &&
559         DeclInfo.IsReferencedInPostZone)
560       return false;
561     if (!DeclInfo.IsReferencedInZone)
562       continue; // no need to pass as parameter, not referenced
563     if (DeclInfo.DeclaredIn == ZoneRelative::Inside ||
564         DeclInfo.DeclaredIn == ZoneRelative::OutsideFunc)
565       continue; // no need to pass as parameter, still accessible.
566     // Parameter specific checks.
567     const ValueDecl *VD = dyn_cast_or_null<ValueDecl>(DeclInfo.TheDecl);
568     // Can't parameterise if the Decl isn't a ValueDecl or is a FunctionDecl
569     // (this includes the case of recursive call to EnclosingFunc in Zone).
570     if (!VD || isa<FunctionDecl>(DeclInfo.TheDecl))
571       return false;
572     // Parameter qualifiers are same as the Decl's qualifiers.
573     QualType TypeInfo = VD->getType().getNonReferenceType();
574     // FIXME: Need better qualifier checks: check mutated status for
575     // Decl(e.g. was it assigned, passed as nonconst argument, etc)
576     // FIXME: check if parameter will be a non l-value reference.
577     // FIXME: We don't want to always pass variables of types like int,
578     // pointers, etc by reference.
579     bool IsPassedByReference = true;
580     // We use the index of declaration as the ordering priority for parameters.
581     ExtractedFunc.Parameters.push_back({std::string(VD->getName()), TypeInfo,
582                                         IsPassedByReference,
583                                         DeclInfo.DeclIndex});
584   }
585   llvm::sort(ExtractedFunc.Parameters);
586   return true;
587 }
588 
589 // Clangd uses open ranges while ExtractionSemicolonPolicy (in Clang Tooling)
590 // uses closed ranges. Generates the semicolon policy for the extraction and
591 // extends the ZoneRange if necessary.
592 tooling::ExtractionSemicolonPolicy
getSemicolonPolicy(ExtractionZone & ExtZone,const SourceManager & SM,const LangOptions & LangOpts)593 getSemicolonPolicy(ExtractionZone &ExtZone, const SourceManager &SM,
594                    const LangOptions &LangOpts) {
595   // Get closed ZoneRange.
596   SourceRange FuncBodyRange = {ExtZone.ZoneRange.getBegin(),
597                                ExtZone.ZoneRange.getEnd().getLocWithOffset(-1)};
598   auto SemicolonPolicy = tooling::ExtractionSemicolonPolicy::compute(
599       ExtZone.getLastRootStmt()->ASTNode.get<Stmt>(), FuncBodyRange, SM,
600       LangOpts);
601   // Update ZoneRange.
602   ExtZone.ZoneRange.setEnd(FuncBodyRange.getEnd().getLocWithOffset(1));
603   return SemicolonPolicy;
604 }
605 
606 // Generate return type for ExtractedFunc. Return false if unable to do so.
generateReturnProperties(NewFunction & ExtractedFunc,const FunctionDecl & EnclosingFunc,const CapturedZoneInfo & CapturedInfo)607 bool generateReturnProperties(NewFunction &ExtractedFunc,
608                               const FunctionDecl &EnclosingFunc,
609                               const CapturedZoneInfo &CapturedInfo) {
610   // If the selected code always returns, we preserve those return statements.
611   // The return type should be the same as the enclosing function.
612   // (Others are possible if there are conversions, but this seems clearest).
613   if (CapturedInfo.HasReturnStmt) {
614     // If the return is conditional, neither replacing the code with
615     // `extracted()` nor `return extracted()` is correct.
616     if (!CapturedInfo.AlwaysReturns)
617       return false;
618     QualType Ret = EnclosingFunc.getReturnType();
619     // Once we support members, it'd be nice to support e.g. extracting a method
620     // of Foo<T> that returns T. But it's not clear when that's safe.
621     if (Ret->isDependentType())
622       return false;
623     ExtractedFunc.ReturnType = Ret;
624     return true;
625   }
626   // FIXME: Generate new return statement if needed.
627   ExtractedFunc.ReturnType = EnclosingFunc.getParentASTContext().VoidTy;
628   return true;
629 }
630 
631 // FIXME: add support for adding other function return types besides void.
632 // FIXME: assign the value returned by non void extracted function.
getExtractedFunction(ExtractionZone & ExtZone,const SourceManager & SM,const LangOptions & LangOpts)633 llvm::Expected<NewFunction> getExtractedFunction(ExtractionZone &ExtZone,
634                                                  const SourceManager &SM,
635                                                  const LangOptions &LangOpts) {
636   CapturedZoneInfo CapturedInfo = captureZoneInfo(ExtZone);
637   // Bail out if any break of continue exists
638   if (CapturedInfo.BrokenControlFlow)
639     return llvm::createStringError(llvm::inconvertibleErrorCode(),
640                                    +"Cannot extract break/continue without "
641                                     "corresponding loop/switch statement.");
642   NewFunction ExtractedFunc(getSemicolonPolicy(ExtZone, SM, LangOpts));
643   ExtractedFunc.BodyRange = ExtZone.ZoneRange;
644   ExtractedFunc.InsertionPoint = ExtZone.getInsertionPoint();
645   ExtractedFunc.EnclosingFuncContext =
646       ExtZone.EnclosingFunction->getDeclContext();
647   ExtractedFunc.CallerReturnsValue = CapturedInfo.AlwaysReturns;
648   if (!createParameters(ExtractedFunc, CapturedInfo) ||
649       !generateReturnProperties(ExtractedFunc, *ExtZone.EnclosingFunction,
650                                 CapturedInfo))
651     return llvm::createStringError(llvm::inconvertibleErrorCode(),
652                                    +"Too complex to extract.");
653   return ExtractedFunc;
654 }
655 
656 class ExtractFunction : public Tweak {
657 public:
658   const char *id() const override final;
659   bool prepare(const Selection &Inputs) override;
660   Expected<Effect> apply(const Selection &Inputs) override;
title() const661   std::string title() const override { return "Extract to function"; }
intent() const662   Intent intent() const override { return Refactor; }
663 
664 private:
665   ExtractionZone ExtZone;
666 };
667 
REGISTER_TWEAK(ExtractFunction)668 REGISTER_TWEAK(ExtractFunction)
669 tooling::Replacement replaceWithFuncCall(const NewFunction &ExtractedFunc,
670                                          const SourceManager &SM,
671                                          const LangOptions &LangOpts) {
672   std::string FuncCall = ExtractedFunc.renderCall();
673   return tooling::Replacement(
674       SM, CharSourceRange(ExtractedFunc.BodyRange, false), FuncCall, LangOpts);
675 }
676 
createFunctionDefinition(const NewFunction & ExtractedFunc,const SourceManager & SM)677 tooling::Replacement createFunctionDefinition(const NewFunction &ExtractedFunc,
678                                               const SourceManager &SM) {
679   std::string FunctionDef = ExtractedFunc.renderDefinition(SM);
680   return tooling::Replacement(SM, ExtractedFunc.InsertionPoint, 0, FunctionDef);
681 }
682 
prepare(const Selection & Inputs)683 bool ExtractFunction::prepare(const Selection &Inputs) {
684   const Node *CommonAnc = Inputs.ASTSelection.commonAncestor();
685   const SourceManager &SM = Inputs.AST->getSourceManager();
686   const LangOptions &LangOpts = Inputs.AST->getLangOpts();
687   if (auto MaybeExtZone = findExtractionZone(CommonAnc, SM, LangOpts)) {
688     ExtZone = std::move(*MaybeExtZone);
689     return true;
690   }
691   return false;
692 }
693 
apply(const Selection & Inputs)694 Expected<Tweak::Effect> ExtractFunction::apply(const Selection &Inputs) {
695   const SourceManager &SM = Inputs.AST->getSourceManager();
696   const LangOptions &LangOpts = Inputs.AST->getLangOpts();
697   auto ExtractedFunc = getExtractedFunction(ExtZone, SM, LangOpts);
698   // FIXME: Add more types of errors.
699   if (!ExtractedFunc)
700     return ExtractedFunc.takeError();
701   tooling::Replacements Result;
702   if (auto Err = Result.add(createFunctionDefinition(*ExtractedFunc, SM)))
703     return std::move(Err);
704   if (auto Err = Result.add(replaceWithFuncCall(*ExtractedFunc, SM, LangOpts)))
705     return std::move(Err);
706   return Effect::mainFileEdit(SM, std::move(Result));
707 }
708 
709 } // namespace
710 } // namespace clangd
711 } // namespace clang
712