1 //===--- Stencil.cpp - Stencil implementation -------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "clang/Tooling/Transformer/Stencil.h"
10 #include "clang/AST/ASTContext.h"
11 #include "clang/AST/ASTTypeTraits.h"
12 #include "clang/AST/Expr.h"
13 #include "clang/ASTMatchers/ASTMatchFinder.h"
14 #include "clang/ASTMatchers/ASTMatchers.h"
15 #include "clang/Basic/SourceLocation.h"
16 #include "clang/Lex/Lexer.h"
17 #include "clang/Tooling/Transformer/SourceCode.h"
18 #include "clang/Tooling/Transformer/SourceCodeBuilders.h"
19 #include "llvm/ADT/SmallVector.h"
20 #include "llvm/ADT/Twine.h"
21 #include "llvm/Support/Errc.h"
22 #include "llvm/Support/Error.h"
23 #include <atomic>
24 #include <memory>
25 #include <string>
26 
27 using namespace clang;
28 using namespace transformer;
29 
30 using ast_matchers::MatchFinder;
31 using llvm::errc;
32 using llvm::Error;
33 using llvm::Expected;
34 using llvm::StringError;
35 
36 static llvm::Expected<DynTypedNode>
37 getNode(const ast_matchers::BoundNodes &Nodes, StringRef Id) {
38   auto &NodesMap = Nodes.getMap();
39   auto It = NodesMap.find(Id);
40   if (It == NodesMap.end())
41     return llvm::make_error<llvm::StringError>(llvm::errc::invalid_argument,
42                                                "Id not bound: " + Id);
43   return It->second;
44 }
45 
46 static Error printNode(StringRef Id, const MatchFinder::MatchResult &Match,
47                        std::string *Result) {
48   std::string Output;
49   llvm::raw_string_ostream Os(Output);
50   auto NodeOrErr = getNode(Match.Nodes, Id);
51   if (auto Err = NodeOrErr.takeError())
52     return Err;
53   NodeOrErr->print(Os, PrintingPolicy(Match.Context->getLangOpts()));
54   *Result += Os.str();
55   return Error::success();
56 }
57 
58 // FIXME: Consider memoizing this function using the `ASTContext`.
59 static bool isSmartPointerType(QualType Ty, ASTContext &Context) {
60   using namespace ::clang::ast_matchers;
61 
62   // Optimization: hard-code common smart-pointer types. This can/should be
63   // removed if we start caching the results of this function.
64   auto KnownSmartPointer =
65       cxxRecordDecl(hasAnyName("::std::unique_ptr", "::std::shared_ptr"));
66   const auto QuacksLikeASmartPointer = cxxRecordDecl(
67       hasMethod(cxxMethodDecl(hasOverloadedOperatorName("->"),
68                               returns(qualType(pointsTo(type()))))),
69       hasMethod(cxxMethodDecl(hasOverloadedOperatorName("*"),
70                               returns(qualType(references(type()))))));
71   const auto SmartPointer = qualType(hasDeclaration(
72       cxxRecordDecl(anyOf(KnownSmartPointer, QuacksLikeASmartPointer))));
73   return match(SmartPointer, Ty, Context).size() > 0;
74 }
75 
76 // Identifies use of `operator*` on smart pointers, and returns the underlying
77 // smart-pointer expression; otherwise, returns null.
78 static const Expr *isSmartDereference(const Expr &E, ASTContext &Context) {
79   using namespace ::clang::ast_matchers;
80 
81   const auto HasOverloadedArrow = cxxRecordDecl(hasMethod(cxxMethodDecl(
82       hasOverloadedOperatorName("->"), returns(qualType(pointsTo(type()))))));
83   // Verify it is a smart pointer by finding `operator->` in the class
84   // declaration.
85   auto Deref = cxxOperatorCallExpr(
86       hasOverloadedOperatorName("*"), hasUnaryOperand(expr().bind("arg")),
87       callee(cxxMethodDecl(ofClass(HasOverloadedArrow))));
88   return selectFirst<Expr>("arg", match(Deref, E, Context));
89 }
90 
91 namespace {
92 // An arbitrary fragment of code within a stencil.
93 class RawTextStencil : public StencilInterface {
94   std::string Text;
95 
96 public:
97   explicit RawTextStencil(std::string T) : Text(std::move(T)) {}
98 
99   std::string toString() const override {
100     std::string Result;
101     llvm::raw_string_ostream OS(Result);
102     OS << "\"";
103     OS.write_escaped(Text);
104     OS << "\"";
105     OS.flush();
106     return Result;
107   }
108 
109   Error eval(const MatchFinder::MatchResult &Match,
110              std::string *Result) const override {
111     Result->append(Text);
112     return Error::success();
113   }
114 };
115 
116 // A debugging operation to dump the AST for a particular (bound) AST node.
117 class DebugPrintNodeStencil : public StencilInterface {
118   std::string Id;
119 
120 public:
121   explicit DebugPrintNodeStencil(std::string S) : Id(std::move(S)) {}
122 
123   std::string toString() const override {
124     return (llvm::Twine("dPrint(\"") + Id + "\")").str();
125   }
126 
127   Error eval(const MatchFinder::MatchResult &Match,
128              std::string *Result) const override {
129     return printNode(Id, Match, Result);
130   }
131 };
132 
// Operators that take a single node Id as an argument. Each one is created by
// the public factory named in its comment and interpreted by
// `UnaryOperationStencil`.
enum class UnaryNodeOperator {
  Parens,         // expression(Id)
  Deref,          // deref(Id)
  MaybeDeref,     // maybeDeref(Id)
  AddressOf,      // addressOf(Id)
  MaybeAddressOf, // maybeAddressOf(Id)
  Describe        // describe(Id)
};
142 
143 // Generic container for stencil operations with a (single) node-id argument.
144 class UnaryOperationStencil : public StencilInterface {
145   UnaryNodeOperator Op;
146   std::string Id;
147 
148 public:
149   UnaryOperationStencil(UnaryNodeOperator Op, std::string Id)
150       : Op(Op), Id(std::move(Id)) {}
151 
152   std::string toString() const override {
153     StringRef OpName;
154     switch (Op) {
155     case UnaryNodeOperator::Parens:
156       OpName = "expression";
157       break;
158     case UnaryNodeOperator::Deref:
159       OpName = "deref";
160       break;
161     case UnaryNodeOperator::MaybeDeref:
162       OpName = "maybeDeref";
163       break;
164     case UnaryNodeOperator::AddressOf:
165       OpName = "addressOf";
166       break;
167     case UnaryNodeOperator::MaybeAddressOf:
168       OpName = "maybeAddressOf";
169       break;
170     case UnaryNodeOperator::Describe:
171       OpName = "describe";
172       break;
173     }
174     return (OpName + "(\"" + Id + "\")").str();
175   }
176 
177   Error eval(const MatchFinder::MatchResult &Match,
178              std::string *Result) const override {
179     // The `Describe` operation can be applied to any node, not just
180     // expressions, so it is handled here, separately.
181     if (Op == UnaryNodeOperator::Describe)
182       return printNode(Id, Match, Result);
183 
184     const auto *E = Match.Nodes.getNodeAs<Expr>(Id);
185     if (E == nullptr)
186       return llvm::make_error<StringError>(errc::invalid_argument,
187                                            "Id not bound or not Expr: " + Id);
188     llvm::Optional<std::string> Source;
189     switch (Op) {
190     case UnaryNodeOperator::Parens:
191       Source = tooling::buildParens(*E, *Match.Context);
192       break;
193     case UnaryNodeOperator::Deref:
194       Source = tooling::buildDereference(*E, *Match.Context);
195       break;
196     case UnaryNodeOperator::MaybeDeref:
197       if (E->getType()->isAnyPointerType() ||
198           isSmartPointerType(E->getType(), *Match.Context)) {
199         // Strip off any operator->. This can only occur inside an actual arrow
200         // member access, so we treat it as equivalent to an actual object
201         // expression.
202         if (const auto *OpCall = dyn_cast<clang::CXXOperatorCallExpr>(E)) {
203           if (OpCall->getOperator() == clang::OO_Arrow &&
204               OpCall->getNumArgs() == 1) {
205             E = OpCall->getArg(0);
206           }
207         }
208         Source = tooling::buildDereference(*E, *Match.Context);
209         break;
210       }
211       *Result += tooling::getText(*E, *Match.Context);
212       return Error::success();
213     case UnaryNodeOperator::AddressOf:
214       Source = tooling::buildAddressOf(*E, *Match.Context);
215       break;
216     case UnaryNodeOperator::MaybeAddressOf:
217       if (E->getType()->isAnyPointerType() ||
218           isSmartPointerType(E->getType(), *Match.Context)) {
219         // Strip off any operator->. This can only occur inside an actual arrow
220         // member access, so we treat it as equivalent to an actual object
221         // expression.
222         if (const auto *OpCall = dyn_cast<clang::CXXOperatorCallExpr>(E)) {
223           if (OpCall->getOperator() == clang::OO_Arrow &&
224               OpCall->getNumArgs() == 1) {
225             E = OpCall->getArg(0);
226           }
227         }
228         *Result += tooling::getText(*E, *Match.Context);
229         return Error::success();
230       }
231       Source = tooling::buildAddressOf(*E, *Match.Context);
232       break;
233     case UnaryNodeOperator::Describe:
234       llvm_unreachable("This case is handled at the start of the function");
235     }
236     if (!Source)
237       return llvm::make_error<StringError>(
238           errc::invalid_argument,
239           "Could not construct expression source from ID: " + Id);
240     *Result += *Source;
241     return Error::success();
242   }
243 };
244 
245 // The fragment of code corresponding to the selected range.
246 class SelectorStencil : public StencilInterface {
247   RangeSelector Selector;
248 
249 public:
250   explicit SelectorStencil(RangeSelector S) : Selector(std::move(S)) {}
251 
252   std::string toString() const override { return "selection(...)"; }
253 
254   Error eval(const MatchFinder::MatchResult &Match,
255              std::string *Result) const override {
256     auto RawRange = Selector(Match);
257     if (!RawRange)
258       return RawRange.takeError();
259     CharSourceRange Range = Lexer::makeFileCharRange(
260         *RawRange, *Match.SourceManager, Match.Context->getLangOpts());
261     if (Range.isInvalid()) {
262       // Validate the original range to attempt to get a meaningful error
263       // message. If it's valid, then something else is the cause and we just
264       // return the generic failure message.
265       if (auto Err =
266               tooling::validateEditRange(*RawRange, *Match.SourceManager))
267         return handleErrors(std::move(Err), [](std::unique_ptr<StringError> E) {
268           assert(E->convertToErrorCode() ==
269                      llvm::make_error_code(errc::invalid_argument) &&
270                  "Validation errors must carry the invalid_argument code");
271           return llvm::createStringError(
272               errc::invalid_argument,
273               "selected range could not be resolved to a valid source range; " +
274                   E->getMessage());
275         });
276       return llvm::createStringError(
277           errc::invalid_argument,
278           "selected range could not be resolved to a valid source range");
279     }
280     // Validate `Range`, because `makeFileCharRange` accepts some ranges that
281     // `validateEditRange` rejects.
282     if (auto Err = tooling::validateEditRange(Range, *Match.SourceManager))
283       return joinErrors(
284           llvm::createStringError(errc::invalid_argument,
285                                   "selected range is not valid for editing"),
286           std::move(Err));
287     *Result += tooling::getText(Range, *Match.Context);
288     return Error::success();
289   }
290 };
291 
292 // A stencil operation to build a member access `e.m` or `e->m`, as appropriate.
293 class AccessStencil : public StencilInterface {
294   std::string BaseId;
295   Stencil Member;
296 
297 public:
298   AccessStencil(StringRef BaseId, Stencil Member)
299       : BaseId(std::string(BaseId)), Member(std::move(Member)) {}
300 
301   std::string toString() const override {
302     return (llvm::Twine("access(\"") + BaseId + "\", " + Member->toString() +
303             ")")
304         .str();
305   }
306 
307   Error eval(const MatchFinder::MatchResult &Match,
308              std::string *Result) const override {
309     const auto *E = Match.Nodes.getNodeAs<Expr>(BaseId);
310     if (E == nullptr)
311       return llvm::make_error<StringError>(errc::invalid_argument,
312                                            "Id not bound: " + BaseId);
313     if (!E->isImplicitCXXThis()) {
314       llvm::Optional<std::string> S;
315       if (E->getType()->isAnyPointerType() ||
316           isSmartPointerType(E->getType(), *Match.Context)) {
317         // Strip off any operator->. This can only occur inside an actual arrow
318         // member access, so we treat it as equivalent to an actual object
319         // expression.
320         if (const auto *OpCall = dyn_cast<clang::CXXOperatorCallExpr>(E)) {
321           if (OpCall->getOperator() == clang::OO_Arrow &&
322               OpCall->getNumArgs() == 1) {
323             E = OpCall->getArg(0);
324           }
325         }
326         S = tooling::buildArrow(*E, *Match.Context);
327       } else if (const auto *Operand = isSmartDereference(*E, *Match.Context)) {
328         // `buildDot` already handles the built-in dereference operator, so we
329         // only need to catch overloaded `operator*`.
330         S = tooling::buildArrow(*Operand, *Match.Context);
331       } else {
332         S = tooling::buildDot(*E, *Match.Context);
333       }
334       if (S.hasValue())
335         *Result += *S;
336       else
337         return llvm::make_error<StringError>(
338             errc::invalid_argument,
339             "Could not construct object text from ID: " + BaseId);
340     }
341     return Member->eval(Match, Result);
342   }
343 };
344 
345 class IfBoundStencil : public StencilInterface {
346   std::string Id;
347   Stencil TrueStencil;
348   Stencil FalseStencil;
349 
350 public:
351   IfBoundStencil(StringRef Id, Stencil TrueStencil, Stencil FalseStencil)
352       : Id(std::string(Id)), TrueStencil(std::move(TrueStencil)),
353         FalseStencil(std::move(FalseStencil)) {}
354 
355   std::string toString() const override {
356     return (llvm::Twine("ifBound(\"") + Id + "\", " + TrueStencil->toString() +
357             ", " + FalseStencil->toString() + ")")
358         .str();
359   }
360 
361   Error eval(const MatchFinder::MatchResult &Match,
362              std::string *Result) const override {
363     auto &M = Match.Nodes.getMap();
364     return (M.find(Id) != M.end() ? TrueStencil : FalseStencil)
365         ->eval(Match, Result);
366   }
367 };
368 
369 class SequenceStencil : public StencilInterface {
370   std::vector<Stencil> Stencils;
371 
372 public:
373   SequenceStencil(std::vector<Stencil> Stencils)
374       : Stencils(std::move(Stencils)) {}
375 
376   std::string toString() const override {
377     llvm::SmallVector<std::string, 2> Parts;
378     Parts.reserve(Stencils.size());
379     for (const auto &S : Stencils)
380       Parts.push_back(S->toString());
381     return (llvm::Twine("seq(") + llvm::join(Parts, ", ") + ")").str();
382   }
383 
384   Error eval(const MatchFinder::MatchResult &Match,
385              std::string *Result) const override {
386     for (const auto &S : Stencils)
387       if (auto Err = S->eval(Match, Result))
388         return Err;
389     return Error::success();
390   }
391 };
392 
393 class RunStencil : public StencilInterface {
394   MatchConsumer<std::string> Consumer;
395 
396 public:
397   explicit RunStencil(MatchConsumer<std::string> C) : Consumer(std::move(C)) {}
398 
399   std::string toString() const override { return "run(...)"; }
400 
401   Error eval(const MatchFinder::MatchResult &Match,
402              std::string *Result) const override {
403 
404     Expected<std::string> Value = Consumer(Match);
405     if (!Value)
406       return Value.takeError();
407     *Result += *Value;
408     return Error::success();
409   }
410 };
411 } // namespace
412 
413 Stencil transformer::detail::makeStencil(StringRef Text) {
414   return std::make_shared<RawTextStencil>(std::string(Text));
415 }
416 
417 Stencil transformer::detail::makeStencil(RangeSelector Selector) {
418   return std::make_shared<SelectorStencil>(std::move(Selector));
419 }
420 
421 Stencil transformer::dPrint(StringRef Id) {
422   return std::make_shared<DebugPrintNodeStencil>(std::string(Id));
423 }
424 
425 Stencil transformer::expression(llvm::StringRef Id) {
426   return std::make_shared<UnaryOperationStencil>(UnaryNodeOperator::Parens,
427                                                  std::string(Id));
428 }
429 
430 Stencil transformer::deref(llvm::StringRef ExprId) {
431   return std::make_shared<UnaryOperationStencil>(UnaryNodeOperator::Deref,
432                                                  std::string(ExprId));
433 }
434 
435 Stencil transformer::maybeDeref(llvm::StringRef ExprId) {
436   return std::make_shared<UnaryOperationStencil>(UnaryNodeOperator::MaybeDeref,
437                                                  std::string(ExprId));
438 }
439 
440 Stencil transformer::addressOf(llvm::StringRef ExprId) {
441   return std::make_shared<UnaryOperationStencil>(UnaryNodeOperator::AddressOf,
442                                                  std::string(ExprId));
443 }
444 
445 Stencil transformer::maybeAddressOf(llvm::StringRef ExprId) {
446   return std::make_shared<UnaryOperationStencil>(
447       UnaryNodeOperator::MaybeAddressOf, std::string(ExprId));
448 }
449 
450 Stencil transformer::describe(StringRef Id) {
451   return std::make_shared<UnaryOperationStencil>(UnaryNodeOperator::Describe,
452                                                  std::string(Id));
453 }
454 
455 Stencil transformer::access(StringRef BaseId, Stencil Member) {
456   return std::make_shared<AccessStencil>(BaseId, std::move(Member));
457 }
458 
459 Stencil transformer::ifBound(StringRef Id, Stencil TrueStencil,
460                              Stencil FalseStencil) {
461   return std::make_shared<IfBoundStencil>(Id, std::move(TrueStencil),
462                                           std::move(FalseStencil));
463 }
464 
465 Stencil transformer::run(MatchConsumer<std::string> Fn) {
466   return std::make_shared<RunStencil>(std::move(Fn));
467 }
468 
469 Stencil transformer::catVector(std::vector<Stencil> Parts) {
470   // Only one argument, so don't wrap in sequence.
471   if (Parts.size() == 1)
472     return std::move(Parts[0]);
473   return std::make_shared<SequenceStencil>(std::move(Parts));
474 }
475