1 //===- Marshallers.h - Generic matcher function marshallers -----*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 /// \file
10 /// Functions templates and classes to wrap matcher construct functions.
11 ///
12 /// A collection of template function and classes that provide a generic
13 /// marshalling layer on top of matcher construct functions.
14 /// These are used by the registry to export all marshaller constructors with
15 /// the same generic interface.
16 //
17 //===----------------------------------------------------------------------===//
18 
19 #ifndef LLVM_CLANG_LIB_ASTMATCHERS_DYNAMIC_MARSHALLERS_H
20 #define LLVM_CLANG_LIB_ASTMATCHERS_DYNAMIC_MARSHALLERS_H
21 
22 #include "clang/AST/ASTTypeTraits.h"
23 #include "clang/AST/OperationKinds.h"
24 #include "clang/ASTMatchers/ASTMatchersInternal.h"
25 #include "clang/ASTMatchers/Dynamic/Diagnostics.h"
26 #include "clang/ASTMatchers/Dynamic/VariantValue.h"
27 #include "clang/Basic/AttrKinds.h"
28 #include "clang/Basic/LLVM.h"
29 #include "clang/Basic/OpenMPKinds.h"
30 #include "clang/Basic/TypeTraits.h"
31 #include "llvm/ADT/ArrayRef.h"
32 #include "llvm/ADT/None.h"
33 #include "llvm/ADT/Optional.h"
34 #include "llvm/ADT/STLExtras.h"
35 #include "llvm/ADT/StringRef.h"
36 #include "llvm/ADT/StringSwitch.h"
37 #include "llvm/ADT/Twine.h"
38 #include "llvm/Support/Regex.h"
39 #include <cassert>
40 #include <cstddef>
41 #include <iterator>
42 #include <limits>
43 #include <memory>
44 #include <string>
45 #include <utility>
46 #include <vector>
47 
48 namespace clang {
49 namespace ast_matchers {
50 namespace dynamic {
51 namespace internal {
52 
/// Helper template class to jump from argument type to the right is/get
///   functions in VariantValue.
/// Used to verify and extract the matcher arguments below.
template <class T> struct ArgTypeTraits;
/// Arguments declared as \c const \c T & reuse the traits of \c T.
template <class T> struct ArgTypeTraits<const T &> : public ArgTypeTraits<T> {
};
59 
60 template <> struct ArgTypeTraits<std::string> {
61   static bool is(const VariantValue &Value) { return Value.isString(); }
62 
63   static const std::string &get(const VariantValue &Value) {
64     return Value.getString();
65   }
66 
67   static ArgKind getKind() {
68     return ArgKind(ArgKind::AK_String);
69   }
70 
71   static llvm::Optional<std::string> getBestGuess(const VariantValue &) {
72     return llvm::None;
73   }
74 };
75 
76 template <>
77 struct ArgTypeTraits<StringRef> : public ArgTypeTraits<std::string> {
78 };
79 
80 template <class T> struct ArgTypeTraits<ast_matchers::internal::Matcher<T>> {
81   static bool is(const VariantValue &Value) {
82     return Value.isMatcher() && Value.getMatcher().hasTypedMatcher<T>();
83   }
84 
85   static ast_matchers::internal::Matcher<T> get(const VariantValue &Value) {
86     return Value.getMatcher().getTypedMatcher<T>();
87   }
88 
89   static ArgKind getKind() {
90     return ArgKind(ASTNodeKind::getFromNodeKind<T>());
91   }
92 
93   static llvm::Optional<std::string> getBestGuess(const VariantValue &) {
94     return llvm::None;
95   }
96 };
97 
98 template <> struct ArgTypeTraits<bool> {
99   static bool is(const VariantValue &Value) { return Value.isBoolean(); }
100 
101   static bool get(const VariantValue &Value) {
102     return Value.getBoolean();
103   }
104 
105   static ArgKind getKind() {
106     return ArgKind(ArgKind::AK_Boolean);
107   }
108 
109   static llvm::Optional<std::string> getBestGuess(const VariantValue &) {
110     return llvm::None;
111   }
112 };
113 
114 template <> struct ArgTypeTraits<double> {
115   static bool is(const VariantValue &Value) { return Value.isDouble(); }
116 
117   static double get(const VariantValue &Value) {
118     return Value.getDouble();
119   }
120 
121   static ArgKind getKind() {
122     return ArgKind(ArgKind::AK_Double);
123   }
124 
125   static llvm::Optional<std::string> getBestGuess(const VariantValue &) {
126     return llvm::None;
127   }
128 };
129 
130 template <> struct ArgTypeTraits<unsigned> {
131   static bool is(const VariantValue &Value) { return Value.isUnsigned(); }
132 
133   static unsigned get(const VariantValue &Value) {
134     return Value.getUnsigned();
135   }
136 
137   static ArgKind getKind() {
138     return ArgKind(ArgKind::AK_Unsigned);
139   }
140 
141   static llvm::Optional<std::string> getBestGuess(const VariantValue &) {
142     return llvm::None;
143   }
144 };
145 
146 template <> struct ArgTypeTraits<attr::Kind> {
147 private:
148   static Optional<attr::Kind> getAttrKind(llvm::StringRef AttrKind) {
149     return llvm::StringSwitch<Optional<attr::Kind>>(AttrKind)
150 #define ATTR(X) .Case("attr::" #X, attr:: X)
151 #include "clang/Basic/AttrList.inc"
152         .Default(llvm::None);
153   }
154 
155 public:
156   static bool is(const VariantValue &Value) {
157     return Value.isString() && getAttrKind(Value.getString());
158   }
159 
160   static attr::Kind get(const VariantValue &Value) {
161     return *getAttrKind(Value.getString());
162   }
163 
164   static ArgKind getKind() {
165     return ArgKind(ArgKind::AK_String);
166   }
167 
168   static llvm::Optional<std::string> getBestGuess(const VariantValue &Value);
169 };
170 
171 template <> struct ArgTypeTraits<CastKind> {
172 private:
173   static Optional<CastKind> getCastKind(llvm::StringRef AttrKind) {
174     return llvm::StringSwitch<Optional<CastKind>>(AttrKind)
175 #define CAST_OPERATION(Name) .Case("CK_" #Name, CK_##Name)
176 #include "clang/AST/OperationKinds.def"
177         .Default(llvm::None);
178   }
179 
180 public:
181   static bool is(const VariantValue &Value) {
182     return Value.isString() && getCastKind(Value.getString());
183   }
184 
185   static CastKind get(const VariantValue &Value) {
186     return *getCastKind(Value.getString());
187   }
188 
189   static ArgKind getKind() {
190     return ArgKind(ArgKind::AK_String);
191   }
192 
193   static llvm::Optional<std::string> getBestGuess(const VariantValue &Value);
194 };
195 
196 template <> struct ArgTypeTraits<llvm::Regex::RegexFlags> {
197 private:
198   static Optional<llvm::Regex::RegexFlags> getFlags(llvm::StringRef Flags);
199 
200 public:
201   static bool is(const VariantValue &Value) {
202     return Value.isString() && getFlags(Value.getString());
203   }
204 
205   static llvm::Regex::RegexFlags get(const VariantValue &Value) {
206     return *getFlags(Value.getString());
207   }
208 
209   static ArgKind getKind() { return ArgKind(ArgKind::AK_String); }
210 
211   static llvm::Optional<std::string> getBestGuess(const VariantValue &Value);
212 };
213 
214 template <> struct ArgTypeTraits<OpenMPClauseKind> {
215 private:
216   static Optional<OpenMPClauseKind> getClauseKind(llvm::StringRef ClauseKind) {
217     return llvm::StringSwitch<Optional<OpenMPClauseKind>>(ClauseKind)
218 #define OMP_CLAUSE_CLASS(Enum, Str, Class) .Case(#Enum, llvm::omp::Clause::Enum)
219 #include "llvm/Frontend/OpenMP/OMPKinds.def"
220         .Default(llvm::None);
221   }
222 
223 public:
224   static bool is(const VariantValue &Value) {
225     return Value.isString() && getClauseKind(Value.getString());
226   }
227 
228   static OpenMPClauseKind get(const VariantValue &Value) {
229     return *getClauseKind(Value.getString());
230   }
231 
232   static ArgKind getKind() { return ArgKind(ArgKind::AK_String); }
233 
234   static llvm::Optional<std::string> getBestGuess(const VariantValue &Value);
235 };
236 
237 template <> struct ArgTypeTraits<UnaryExprOrTypeTrait> {
238 private:
239   static Optional<UnaryExprOrTypeTrait>
240   getUnaryOrTypeTraitKind(llvm::StringRef ClauseKind) {
241     return llvm::StringSwitch<Optional<UnaryExprOrTypeTrait>>(ClauseKind)
242 #define UNARY_EXPR_OR_TYPE_TRAIT(Spelling, Name, Key)                          \
243   .Case("UETT_" #Name, UETT_##Name)
244 #define CXX11_UNARY_EXPR_OR_TYPE_TRAIT(Spelling, Name, Key)                    \
245   .Case("UETT_" #Name, UETT_##Name)
246 #include "clang/Basic/TokenKinds.def"
247         .Default(llvm::None);
248   }
249 
250 public:
251   static bool is(const VariantValue &Value) {
252     return Value.isString() && getUnaryOrTypeTraitKind(Value.getString());
253   }
254 
255   static UnaryExprOrTypeTrait get(const VariantValue &Value) {
256     return *getUnaryOrTypeTraitKind(Value.getString());
257   }
258 
259   static ArgKind getKind() { return ArgKind(ArgKind::AK_String); }
260 
261   static llvm::Optional<std::string> getBestGuess(const VariantValue &Value);
262 };
263 
/// Matcher descriptor interface.
///
/// Provides a \c create() method that constructs the matcher from the provided
/// arguments, and various other methods for type introspection.
class MatcherDescriptor {
public:
  virtual ~MatcherDescriptor() = default;

  /// Constructs the matcher from \p Args.
  ///
  /// \param NameRange Source range handed to the implementation; the
  ///   implementations in this file use it as the location of argument-count
  ///   and ambiguity errors.
  /// \param Args The parsed arguments, each carrying its value and range.
  /// \param Error Sink for diagnostics. The implementations in this file
  ///   dereference it unconditionally when reporting an error.
  /// \return The constructed matcher; the implementations in this file return
  ///   a default-constructed VariantMatcher after reporting an error.
  virtual VariantMatcher create(SourceRange NameRange,
                                ArrayRef<ParserValue> Args,
                                Diagnostics *Error) const = 0;

  /// Returns whether the matcher is variadic. Variadic matchers can take any
  /// number of arguments, but they must be of the same type.
  virtual bool isVariadic() const = 0;

  /// Returns the number of arguments accepted by the matcher if not variadic.
  virtual unsigned getNumArgs() const = 0;

  /// Given that the matcher is being converted to type \p ThisKind, append the
  /// set of argument types accepted for argument \p ArgNo to \p ArgKinds.
  // FIXME: We should provide the ability to constrain the output of this
  // function based on the types of other matcher arguments.
  virtual void getArgKinds(ASTNodeKind ThisKind, unsigned ArgNo,
                           std::vector<ArgKind> &ArgKinds) const = 0;

  /// Returns whether this matcher is convertible to the given type.  If it is
  /// so convertible, store in *Specificity a value corresponding to the
  /// "specificity" of the converted matcher to the given context, and in
  /// *LeastDerivedKind the least derived matcher kind which would result in the
  /// same matcher overload.  Zero specificity indicates that this conversion
  /// would produce a trivial matcher that will either always or never match.
  /// Such matchers are excluded from code completion results.
  ///
  /// Both out-parameters default to \c nullptr on this base declaration.
  virtual bool
  isConvertibleTo(ASTNodeKind Kind, unsigned *Specificity = nullptr,
                  ASTNodeKind *LeastDerivedKind = nullptr) const = 0;

  /// Returns whether the matcher will, given a matcher of any type T, yield a
  /// matcher of type T.
  virtual bool isPolymorphic() const { return false; }
};
305 
306 inline bool isRetKindConvertibleTo(ArrayRef<ASTNodeKind> RetKinds,
307                                    ASTNodeKind Kind, unsigned *Specificity,
308                                    ASTNodeKind *LeastDerivedKind) {
309   for (const ASTNodeKind &NodeKind : RetKinds) {
310     if (ArgKind(NodeKind).isConvertibleTo(Kind, Specificity)) {
311       if (LeastDerivedKind)
312         *LeastDerivedKind = NodeKind;
313       return true;
314     }
315   }
316   return false;
317 }
318 
319 /// Simple callback implementation. Marshaller and function are provided.
320 ///
321 /// This class wraps a function of arbitrary signature and a marshaller
322 /// function into a MatcherDescriptor.
323 /// The marshaller is in charge of taking the VariantValue arguments, checking
324 /// their types, unpacking them and calling the underlying function.
325 class FixedArgCountMatcherDescriptor : public MatcherDescriptor {
326 public:
327   using MarshallerType = VariantMatcher (*)(void (*Func)(),
328                                             StringRef MatcherName,
329                                             SourceRange NameRange,
330                                             ArrayRef<ParserValue> Args,
331                                             Diagnostics *Error);
332 
333   /// \param Marshaller Function to unpack the arguments and call \c Func
334   /// \param Func Matcher construct function. This is the function that
335   ///   compile-time matcher expressions would use to create the matcher.
336   /// \param RetKinds The list of matcher types to which the matcher is
337   ///   convertible.
338   /// \param ArgKinds The types of the arguments this matcher takes.
339   FixedArgCountMatcherDescriptor(MarshallerType Marshaller, void (*Func)(),
340                                  StringRef MatcherName,
341                                  ArrayRef<ASTNodeKind> RetKinds,
342                                  ArrayRef<ArgKind> ArgKinds)
343       : Marshaller(Marshaller), Func(Func), MatcherName(MatcherName),
344         RetKinds(RetKinds.begin(), RetKinds.end()),
345         ArgKinds(ArgKinds.begin(), ArgKinds.end()) {}
346 
347   VariantMatcher create(SourceRange NameRange,
348                         ArrayRef<ParserValue> Args,
349                         Diagnostics *Error) const override {
350     return Marshaller(Func, MatcherName, NameRange, Args, Error);
351   }
352 
353   bool isVariadic() const override { return false; }
354   unsigned getNumArgs() const override { return ArgKinds.size(); }
355 
356   void getArgKinds(ASTNodeKind ThisKind, unsigned ArgNo,
357                    std::vector<ArgKind> &Kinds) const override {
358     Kinds.push_back(ArgKinds[ArgNo]);
359   }
360 
361   bool isConvertibleTo(ASTNodeKind Kind, unsigned *Specificity,
362                        ASTNodeKind *LeastDerivedKind) const override {
363     return isRetKindConvertibleTo(RetKinds, Kind, Specificity,
364                                   LeastDerivedKind);
365   }
366 
367 private:
368   const MarshallerType Marshaller;
369   void (* const Func)();
370   const std::string MatcherName;
371   const std::vector<ASTNodeKind> RetKinds;
372   const std::vector<ArgKind> ArgKinds;
373 };
374 
375 /// Helper methods to extract and merge all possible typed matchers
376 /// out of the polymorphic object.
377 template <class PolyMatcher>
378 static void mergePolyMatchers(const PolyMatcher &Poly,
379                               std::vector<DynTypedMatcher> &Out,
380                               ast_matchers::internal::EmptyTypeList) {}
381 
382 template <class PolyMatcher, class TypeList>
383 static void mergePolyMatchers(const PolyMatcher &Poly,
384                               std::vector<DynTypedMatcher> &Out, TypeList) {
385   Out.push_back(ast_matchers::internal::Matcher<typename TypeList::head>(Poly));
386   mergePolyMatchers(Poly, Out, typename TypeList::tail());
387 }
388 
389 /// Convert the return values of the functions into a VariantMatcher.
390 ///
391 /// There are 2 cases right now: The return value is a Matcher<T> or is a
392 /// polymorphic matcher. For the former, we just construct the VariantMatcher.
393 /// For the latter, we instantiate all the possible Matcher<T> of the poly
394 /// matcher.
395 inline VariantMatcher outvalueToVariantMatcher(const DynTypedMatcher &Matcher) {
396   return VariantMatcher::SingleMatcher(Matcher);
397 }
398 
399 template <typename T>
400 static VariantMatcher outvalueToVariantMatcher(const T &PolyMatcher,
401                                                typename T::ReturnTypes * =
402                                                    nullptr) {
403   std::vector<DynTypedMatcher> Matchers;
404   mergePolyMatchers(PolyMatcher, Matchers, typename T::ReturnTypes());
405   VariantMatcher Out = VariantMatcher::PolymorphicMatcher(std::move(Matchers));
406   return Out;
407 }
408 
409 template <typename T>
410 inline void
411 buildReturnTypeVectorFromTypeList(std::vector<ASTNodeKind> &RetTypes) {
412   RetTypes.push_back(ASTNodeKind::getFromNodeKind<typename T::head>());
413   buildReturnTypeVectorFromTypeList<typename T::tail>(RetTypes);
414 }
415 
416 template <>
417 inline void
418 buildReturnTypeVectorFromTypeList<ast_matchers::internal::EmptyTypeList>(
419     std::vector<ASTNodeKind> &RetTypes) {}
420 
421 template <typename T>
422 struct BuildReturnTypeVector {
423   static void build(std::vector<ASTNodeKind> &RetTypes) {
424     buildReturnTypeVectorFromTypeList<typename T::ReturnTypes>(RetTypes);
425   }
426 };
427 
428 template <typename T>
429 struct BuildReturnTypeVector<ast_matchers::internal::Matcher<T>> {
430   static void build(std::vector<ASTNodeKind> &RetTypes) {
431     RetTypes.push_back(ASTNodeKind::getFromNodeKind<T>());
432   }
433 };
434 
435 template <typename T>
436 struct BuildReturnTypeVector<ast_matchers::internal::BindableMatcher<T>> {
437   static void build(std::vector<ASTNodeKind> &RetTypes) {
438     RetTypes.push_back(ASTNodeKind::getFromNodeKind<T>());
439   }
440 };
441 
442 /// Variadic marshaller function.
443 template <typename ResultT, typename ArgT,
444           ResultT (*Func)(ArrayRef<const ArgT *>)>
445 VariantMatcher
446 variadicMatcherDescriptor(StringRef MatcherName, SourceRange NameRange,
447                           ArrayRef<ParserValue> Args, Diagnostics *Error) {
448   ArgT **InnerArgs = new ArgT *[Args.size()]();
449 
450   bool HasError = false;
451   for (size_t i = 0, e = Args.size(); i != e; ++i) {
452     using ArgTraits = ArgTypeTraits<ArgT>;
453 
454     const ParserValue &Arg = Args[i];
455     const VariantValue &Value = Arg.Value;
456     if (!ArgTraits::is(Value)) {
457       Error->addError(Arg.Range, Error->ET_RegistryWrongArgType)
458           << (i + 1) << ArgTraits::getKind().asString() << Value.getTypeAsString();
459       HasError = true;
460       break;
461     }
462     InnerArgs[i] = new ArgT(ArgTraits::get(Value));
463   }
464 
465   VariantMatcher Out;
466   if (!HasError) {
467     Out = outvalueToVariantMatcher(Func(llvm::makeArrayRef(InnerArgs,
468                                                            Args.size())));
469   }
470 
471   for (size_t i = 0, e = Args.size(); i != e; ++i) {
472     delete InnerArgs[i];
473   }
474   delete[] InnerArgs;
475   return Out;
476 }
477 
478 /// Matcher descriptor for variadic functions.
479 ///
480 /// This class simply wraps a VariadicFunction with the right signature to export
481 /// it as a MatcherDescriptor.
482 /// This allows us to have one implementation of the interface for as many free
483 /// functions as we want, reducing the number of symbols and size of the
484 /// object file.
485 class VariadicFuncMatcherDescriptor : public MatcherDescriptor {
486 public:
487   using RunFunc = VariantMatcher (*)(StringRef MatcherName,
488                                      SourceRange NameRange,
489                                      ArrayRef<ParserValue> Args,
490                                      Diagnostics *Error);
491 
492   template <typename ResultT, typename ArgT,
493             ResultT (*F)(ArrayRef<const ArgT *>)>
494   VariadicFuncMatcherDescriptor(
495       ast_matchers::internal::VariadicFunction<ResultT, ArgT, F> Func,
496       StringRef MatcherName)
497       : Func(&variadicMatcherDescriptor<ResultT, ArgT, F>),
498         MatcherName(MatcherName.str()),
499         ArgsKind(ArgTypeTraits<ArgT>::getKind()) {
500     BuildReturnTypeVector<ResultT>::build(RetKinds);
501   }
502 
503   VariantMatcher create(SourceRange NameRange,
504                         ArrayRef<ParserValue> Args,
505                         Diagnostics *Error) const override {
506     return Func(MatcherName, NameRange, Args, Error);
507   }
508 
509   bool isVariadic() const override { return true; }
510   unsigned getNumArgs() const override { return 0; }
511 
512   void getArgKinds(ASTNodeKind ThisKind, unsigned ArgNo,
513                    std::vector<ArgKind> &Kinds) const override {
514     Kinds.push_back(ArgsKind);
515   }
516 
517   bool isConvertibleTo(ASTNodeKind Kind, unsigned *Specificity,
518                        ASTNodeKind *LeastDerivedKind) const override {
519     return isRetKindConvertibleTo(RetKinds, Kind, Specificity,
520                                   LeastDerivedKind);
521   }
522 
523 private:
524   const RunFunc Func;
525   const std::string MatcherName;
526   std::vector<ASTNodeKind> RetKinds;
527   const ArgKind ArgsKind;
528 };
529 
530 /// Return CK_Trivial when appropriate for VariadicDynCastAllOfMatchers.
531 class DynCastAllOfMatcherDescriptor : public VariadicFuncMatcherDescriptor {
532 public:
533   template <typename BaseT, typename DerivedT>
534   DynCastAllOfMatcherDescriptor(
535       ast_matchers::internal::VariadicDynCastAllOfMatcher<BaseT, DerivedT> Func,
536       StringRef MatcherName)
537       : VariadicFuncMatcherDescriptor(Func, MatcherName),
538         DerivedKind(ASTNodeKind::getFromNodeKind<DerivedT>()) {}
539 
540   bool isConvertibleTo(ASTNodeKind Kind, unsigned *Specificity,
541                        ASTNodeKind *LeastDerivedKind) const override {
542     // If Kind is not a base of DerivedKind, either DerivedKind is a base of
543     // Kind (in which case the match will always succeed) or Kind and
544     // DerivedKind are unrelated (in which case it will always fail), so set
545     // Specificity to 0.
546     if (VariadicFuncMatcherDescriptor::isConvertibleTo(Kind, Specificity,
547                                                  LeastDerivedKind)) {
548       if (Kind.isSame(DerivedKind) || !Kind.isBaseOf(DerivedKind)) {
549         if (Specificity)
550           *Specificity = 0;
551       }
552       return true;
553     } else {
554       return false;
555     }
556   }
557 
558 private:
559   const ASTNodeKind DerivedKind;
560 };
561 
/// Helper macros to check the arguments on all marshaller functions.
///
/// Both macros expand to an \c if statement inside a marshaller body and rely
/// on names from the enclosing function: \c Args and \c Error for both, plus
/// \c NameRange for CHECK_ARG_COUNT. When a check fails they add a diagnostic
/// to \c Error and return a default-constructed VariantMatcher from the
/// enclosing function.

// CHECK_ARG_COUNT(count): fails with ET_RegistryWrongArgCount, streaming the
// expected and the actual argument count, unless Args.size() equals count.
#define CHECK_ARG_COUNT(count)                                                 \
  if (Args.size() != count) {                                                  \
    Error->addError(NameRange, Error->ET_RegistryWrongArgCount)                \
        << count << Args.size();                                               \
    return VariantMatcher();                                                   \
  }

// CHECK_ARG_TYPE(index, type): fails unless ArgTypeTraits<type>::is() accepts
// Args[index]. If the traits offer a best guess, the diagnostic is
// ET_RegistryUnknownEnumWithReplace carrying the given string and the
// suggestion; otherwise it is ET_RegistryWrongArgType carrying the expected
// and the actual type. The reported argument position is 1-based (index + 1).
#define CHECK_ARG_TYPE(index, type)                                            \
  if (!ArgTypeTraits<type>::is(Args[index].Value)) {                           \
    if (llvm::Optional<std::string> BestGuess =                                \
            ArgTypeTraits<type>::getBestGuess(Args[index].Value)) {            \
      Error->addError(Args[index].Range,                                       \
                      Error->ET_RegistryUnknownEnumWithReplace)                \
          << index + 1 << Args[index].Value.getString() << *BestGuess;         \
    } else {                                                                   \
      Error->addError(Args[index].Range, Error->ET_RegistryWrongArgType)       \
          << (index + 1) << ArgTypeTraits<type>::getKind().asString()          \
          << Args[index].Value.getTypeAsString();                              \
    }                                                                          \
    return VariantMatcher();                                                   \
  }
584 
585 /// 0-arg marshaller function.
586 template <typename ReturnType>
587 static VariantMatcher matcherMarshall0(void (*Func)(), StringRef MatcherName,
588                                        SourceRange NameRange,
589                                        ArrayRef<ParserValue> Args,
590                                        Diagnostics *Error) {
591   using FuncType = ReturnType (*)();
592   CHECK_ARG_COUNT(0);
593   return outvalueToVariantMatcher(reinterpret_cast<FuncType>(Func)());
594 }
595 
596 /// 1-arg marshaller function.
597 template <typename ReturnType, typename ArgType1>
598 static VariantMatcher matcherMarshall1(void (*Func)(), StringRef MatcherName,
599                                        SourceRange NameRange,
600                                        ArrayRef<ParserValue> Args,
601                                        Diagnostics *Error) {
602   using FuncType = ReturnType (*)(ArgType1);
603   CHECK_ARG_COUNT(1);
604   CHECK_ARG_TYPE(0, ArgType1);
605   return outvalueToVariantMatcher(reinterpret_cast<FuncType>(Func)(
606       ArgTypeTraits<ArgType1>::get(Args[0].Value)));
607 }
608 
609 /// 2-arg marshaller function.
610 template <typename ReturnType, typename ArgType1, typename ArgType2>
611 static VariantMatcher matcherMarshall2(void (*Func)(), StringRef MatcherName,
612                                        SourceRange NameRange,
613                                        ArrayRef<ParserValue> Args,
614                                        Diagnostics *Error) {
615   using FuncType = ReturnType (*)(ArgType1, ArgType2);
616   CHECK_ARG_COUNT(2);
617   CHECK_ARG_TYPE(0, ArgType1);
618   CHECK_ARG_TYPE(1, ArgType2);
619   return outvalueToVariantMatcher(reinterpret_cast<FuncType>(Func)(
620       ArgTypeTraits<ArgType1>::get(Args[0].Value),
621       ArgTypeTraits<ArgType2>::get(Args[1].Value)));
622 }
623 
624 #undef CHECK_ARG_COUNT
625 #undef CHECK_ARG_TYPE
626 
/// Helper class used to collect all the possible overloads of an
///   argument adaptative matcher function.
///
/// All work happens in the constructor, which calls collect() on a
/// default-constructed \c FromTypes list. The recursive collect() overload is
/// only declared in this class; its definition is not visible here.
template <template <typename ToArg, typename FromArg> class ArgumentAdapterT,
          typename FromTypes, typename ToTypes>
class AdaptativeOverloadCollector {
public:
  /// \param Name Matcher name, kept as a non-owning StringRef.
  /// \param Out Vector held by reference for the collected descriptors.
  AdaptativeOverloadCollector(
      StringRef Name, std::vector<std::unique_ptr<MatcherDescriptor>> &Out)
      : Name(Name), Out(Out) {
    collect(FromTypes());
  }

private:
  using AdaptativeFunc = ast_matchers::internal::ArgumentAdaptingMatcherFunc<
      ArgumentAdapterT, FromTypes, ToTypes>;

  /// End case for the recursion
  static void collect(ast_matchers::internal::EmptyTypeList) {}

  /// Recursive case. Get the overload for the head of the list, and
  ///   recurse to the tail.
  template <typename FromTypeList>
  inline void collect(FromTypeList);

  StringRef Name;
  std::vector<std::unique_ptr<MatcherDescriptor>> &Out;
};
654 
655 /// MatcherDescriptor that wraps multiple "overloads" of the same
656 ///   matcher.
657 ///
658 /// It will try every overload and generate appropriate errors for when none or
659 /// more than one overloads match the arguments.
660 class OverloadedMatcherDescriptor : public MatcherDescriptor {
661 public:
662   OverloadedMatcherDescriptor(
663       MutableArrayRef<std::unique_ptr<MatcherDescriptor>> Callbacks)
664       : Overloads(std::make_move_iterator(Callbacks.begin()),
665                   std::make_move_iterator(Callbacks.end())) {}
666 
667   ~OverloadedMatcherDescriptor() override = default;
668 
669   VariantMatcher create(SourceRange NameRange,
670                         ArrayRef<ParserValue> Args,
671                         Diagnostics *Error) const override {
672     std::vector<VariantMatcher> Constructed;
673     Diagnostics::OverloadContext Ctx(Error);
674     for (const auto &O : Overloads) {
675       VariantMatcher SubMatcher = O->create(NameRange, Args, Error);
676       if (!SubMatcher.isNull()) {
677         Constructed.push_back(SubMatcher);
678       }
679     }
680 
681     if (Constructed.empty()) return VariantMatcher(); // No overload matched.
682     // We ignore the errors if any matcher succeeded.
683     Ctx.revertErrors();
684     if (Constructed.size() > 1) {
685       // More than one constructed. It is ambiguous.
686       Error->addError(NameRange, Error->ET_RegistryAmbiguousOverload);
687       return VariantMatcher();
688     }
689     return Constructed[0];
690   }
691 
692   bool isVariadic() const override {
693     bool Overload0Variadic = Overloads[0]->isVariadic();
694 #ifndef NDEBUG
695     for (const auto &O : Overloads) {
696       assert(Overload0Variadic == O->isVariadic());
697     }
698 #endif
699     return Overload0Variadic;
700   }
701 
702   unsigned getNumArgs() const override {
703     unsigned Overload0NumArgs = Overloads[0]->getNumArgs();
704 #ifndef NDEBUG
705     for (const auto &O : Overloads) {
706       assert(Overload0NumArgs == O->getNumArgs());
707     }
708 #endif
709     return Overload0NumArgs;
710   }
711 
712   void getArgKinds(ASTNodeKind ThisKind, unsigned ArgNo,
713                    std::vector<ArgKind> &Kinds) const override {
714     for (const auto &O : Overloads) {
715       if (O->isConvertibleTo(ThisKind))
716         O->getArgKinds(ThisKind, ArgNo, Kinds);
717     }
718   }
719 
720   bool isConvertibleTo(ASTNodeKind Kind, unsigned *Specificity,
721                        ASTNodeKind *LeastDerivedKind) const override {
722     for (const auto &O : Overloads) {
723       if (O->isConvertibleTo(Kind, Specificity, LeastDerivedKind))
724         return true;
725     }
726     return false;
727   }
728 
729 private:
730   std::vector<std::unique_ptr<MatcherDescriptor>> Overloads;
731 };
732 
733 template <typename ReturnType>
734 class RegexMatcherDescriptor : public MatcherDescriptor {
735 public:
736   RegexMatcherDescriptor(ReturnType (*WithFlags)(StringRef,
737                                                  llvm::Regex::RegexFlags),
738                          ReturnType (*NoFlags)(StringRef),
739                          ArrayRef<ASTNodeKind> RetKinds)
740       : WithFlags(WithFlags), NoFlags(NoFlags),
741         RetKinds(RetKinds.begin(), RetKinds.end()) {}
742   bool isVariadic() const override { return true; }
743   unsigned getNumArgs() const override { return 0; }
744 
745   void getArgKinds(ASTNodeKind ThisKind, unsigned ArgNo,
746                    std::vector<ArgKind> &Kinds) const override {
747     assert(ArgNo < 2);
748     Kinds.push_back(ArgKind::AK_String);
749   }
750 
751   bool isConvertibleTo(ASTNodeKind Kind, unsigned *Specificity,
752                        ASTNodeKind *LeastDerivedKind) const override {
753     return isRetKindConvertibleTo(RetKinds, Kind, Specificity,
754                                   LeastDerivedKind);
755   }
756 
757   VariantMatcher create(SourceRange NameRange, ArrayRef<ParserValue> Args,
758                         Diagnostics *Error) const override {
759     if (Args.size() < 1 || Args.size() > 2) {
760       Error->addError(NameRange, Diagnostics::ET_RegistryWrongArgCount)
761           << "1 or 2" << Args.size();
762       return VariantMatcher();
763     }
764     if (!ArgTypeTraits<StringRef>::is(Args[0].Value)) {
765       Error->addError(Args[0].Range, Error->ET_RegistryWrongArgType)
766           << 1 << ArgTypeTraits<StringRef>::getKind().asString()
767           << Args[0].Value.getTypeAsString();
768       return VariantMatcher();
769     }
770     if (Args.size() == 1) {
771       return outvalueToVariantMatcher(
772           NoFlags(ArgTypeTraits<StringRef>::get(Args[0].Value)));
773     }
774     if (!ArgTypeTraits<llvm::Regex::RegexFlags>::is(Args[1].Value)) {
775       if (llvm::Optional<std::string> BestGuess =
776               ArgTypeTraits<llvm::Regex::RegexFlags>::getBestGuess(
777                   Args[1].Value)) {
778         Error->addError(Args[1].Range, Error->ET_RegistryUnknownEnumWithReplace)
779             << 2 << Args[1].Value.getString() << *BestGuess;
780       } else {
781         Error->addError(Args[1].Range, Error->ET_RegistryWrongArgType)
782             << 2 << ArgTypeTraits<llvm::Regex::RegexFlags>::getKind().asString()
783             << Args[1].Value.getTypeAsString();
784       }
785       return VariantMatcher();
786     }
787     return outvalueToVariantMatcher(
788         WithFlags(ArgTypeTraits<StringRef>::get(Args[0].Value),
789                   ArgTypeTraits<llvm::Regex::RegexFlags>::get(Args[1].Value)));
790   }
791 
792 private:
  /// Function called by create() when both the string argument and the flags
  /// argument are supplied.
  ReturnType (*const WithFlags)(StringRef, llvm::Regex::RegexFlags);
  /// Function called by create() when only the string argument is supplied.
  ReturnType (*const NoFlags)(StringRef);
  /// Node kinds passed to isRetKindConvertibleTo() by isConvertibleTo().
  const std::vector<ASTNodeKind> RetKinds;
796 };
797 
798 /// Variadic operator marshaller function.
799 class VariadicOperatorMatcherDescriptor : public MatcherDescriptor {
800 public:
801   using VarOp = DynTypedMatcher::VariadicOperator;
802 
803   VariadicOperatorMatcherDescriptor(unsigned MinCount, unsigned MaxCount,
804                                     VarOp Op, StringRef MatcherName)
805       : MinCount(MinCount), MaxCount(MaxCount), Op(Op),
806         MatcherName(MatcherName) {}
807 
808   VariantMatcher create(SourceRange NameRange,
809                         ArrayRef<ParserValue> Args,
810                         Diagnostics *Error) const override {
811     if (Args.size() < MinCount || MaxCount < Args.size()) {
812       const std::string MaxStr =
813           (MaxCount == std::numeric_limits<unsigned>::max() ? ""
814                                                             : Twine(MaxCount))
815               .str();
816       Error->addError(NameRange, Error->ET_RegistryWrongArgCount)
817           << ("(" + Twine(MinCount) + ", " + MaxStr + ")") << Args.size();
818       return VariantMatcher();
819     }
820 
821     std::vector<VariantMatcher> InnerArgs;
822     for (size_t i = 0, e = Args.size(); i != e; ++i) {
823       const ParserValue &Arg = Args[i];
824       const VariantValue &Value = Arg.Value;
825       if (!Value.isMatcher()) {
826         Error->addError(Arg.Range, Error->ET_RegistryWrongArgType)
827             << (i + 1) << "Matcher<>" << Value.getTypeAsString();
828         return VariantMatcher();
829       }
830       InnerArgs.push_back(Value.getMatcher());
831     }
832     return VariantMatcher::VariadicOperatorMatcher(Op, std::move(InnerArgs));
833   }
834 
835   bool isVariadic() const override { return true; }
836   unsigned getNumArgs() const override { return 0; }
837 
838   void getArgKinds(ASTNodeKind ThisKind, unsigned ArgNo,
839                    std::vector<ArgKind> &Kinds) const override {
840     Kinds.push_back(ThisKind);
841   }
842 
843   bool isConvertibleTo(ASTNodeKind Kind, unsigned *Specificity,
844                        ASTNodeKind *LeastDerivedKind) const override {
845     if (Specificity)
846       *Specificity = 1;
847     if (LeastDerivedKind)
848       *LeastDerivedKind = Kind;
849     return true;
850   }
851 
852   bool isPolymorphic() const override { return true; }
853 
854 private:
855   const unsigned MinCount;
856   const unsigned MaxCount;
857   const VarOp Op;
858   const StringRef MatcherName;
859 };
860 
/// Helper functions to select the appropriate marshaller functions.
/// They detect the number of arguments, argument types and return type.
863 
864 /// 0-arg overload
865 template <typename ReturnType>
866 std::unique_ptr<MatcherDescriptor>
867 makeMatcherAutoMarshall(ReturnType (*Func)(), StringRef MatcherName) {
868   std::vector<ASTNodeKind> RetTypes;
869   BuildReturnTypeVector<ReturnType>::build(RetTypes);
870   return std::make_unique<FixedArgCountMatcherDescriptor>(
871       matcherMarshall0<ReturnType>, reinterpret_cast<void (*)()>(Func),
872       MatcherName, RetTypes, None);
873 }
874 
875 /// 1-arg overload
876 template <typename ReturnType, typename ArgType1>
877 std::unique_ptr<MatcherDescriptor>
878 makeMatcherAutoMarshall(ReturnType (*Func)(ArgType1), StringRef MatcherName) {
879   std::vector<ASTNodeKind> RetTypes;
880   BuildReturnTypeVector<ReturnType>::build(RetTypes);
881   ArgKind AK = ArgTypeTraits<ArgType1>::getKind();
882   return std::make_unique<FixedArgCountMatcherDescriptor>(
883       matcherMarshall1<ReturnType, ArgType1>,
884       reinterpret_cast<void (*)()>(Func), MatcherName, RetTypes, AK);
885 }
886 
887 /// 2-arg overload
888 template <typename ReturnType, typename ArgType1, typename ArgType2>
889 std::unique_ptr<MatcherDescriptor>
890 makeMatcherAutoMarshall(ReturnType (*Func)(ArgType1, ArgType2),
891                         StringRef MatcherName) {
892   std::vector<ASTNodeKind> RetTypes;
893   BuildReturnTypeVector<ReturnType>::build(RetTypes);
894   ArgKind AKs[] = { ArgTypeTraits<ArgType1>::getKind(),
895                     ArgTypeTraits<ArgType2>::getKind() };
896   return std::make_unique<FixedArgCountMatcherDescriptor>(
897       matcherMarshall2<ReturnType, ArgType1, ArgType2>,
898       reinterpret_cast<void (*)()>(Func), MatcherName, RetTypes, AKs);
899 }
900 
901 template <typename ReturnType>
902 std::unique_ptr<MatcherDescriptor> makeMatcherRegexMarshall(
903     ReturnType (*FuncFlags)(llvm::StringRef, llvm::Regex::RegexFlags),
904     ReturnType (*Func)(llvm::StringRef)) {
905   std::vector<ASTNodeKind> RetTypes;
906   BuildReturnTypeVector<ReturnType>::build(RetTypes);
907   return std::make_unique<RegexMatcherDescriptor<ReturnType>>(FuncFlags, Func,
908                                                               RetTypes);
909 }
910 
911 /// Variadic overload.
912 template <typename ResultT, typename ArgT,
913           ResultT (*Func)(ArrayRef<const ArgT *>)>
914 std::unique_ptr<MatcherDescriptor> makeMatcherAutoMarshall(
915     ast_matchers::internal::VariadicFunction<ResultT, ArgT, Func> VarFunc,
916     StringRef MatcherName) {
917   return std::make_unique<VariadicFuncMatcherDescriptor>(VarFunc, MatcherName);
918 }
919 
920 /// Overload for VariadicDynCastAllOfMatchers.
921 ///
922 /// Not strictly necessary, but DynCastAllOfMatcherDescriptor gives us better
923 /// completion results for that type of matcher.
924 template <typename BaseT, typename DerivedT>
925 std::unique_ptr<MatcherDescriptor> makeMatcherAutoMarshall(
926     ast_matchers::internal::VariadicDynCastAllOfMatcher<BaseT, DerivedT>
927         VarFunc,
928     StringRef MatcherName) {
929   return std::make_unique<DynCastAllOfMatcherDescriptor>(VarFunc, MatcherName);
930 }
931 
932 /// Argument adaptative overload.
933 template <template <typename ToArg, typename FromArg> class ArgumentAdapterT,
934           typename FromTypes, typename ToTypes>
935 std::unique_ptr<MatcherDescriptor> makeMatcherAutoMarshall(
936     ast_matchers::internal::ArgumentAdaptingMatcherFunc<ArgumentAdapterT,
937                                                         FromTypes, ToTypes>,
938     StringRef MatcherName) {
939   std::vector<std::unique_ptr<MatcherDescriptor>> Overloads;
940   AdaptativeOverloadCollector<ArgumentAdapterT, FromTypes, ToTypes>(MatcherName,
941                                                                     Overloads);
942   return std::make_unique<OverloadedMatcherDescriptor>(Overloads);
943 }
944 
945 template <template <typename ToArg, typename FromArg> class ArgumentAdapterT,
946           typename FromTypes, typename ToTypes>
947 template <typename FromTypeList>
948 inline void AdaptativeOverloadCollector<ArgumentAdapterT, FromTypes,
949                                         ToTypes>::collect(FromTypeList) {
950   Out.push_back(makeMatcherAutoMarshall(
951       &AdaptativeFunc::template create<typename FromTypeList::head>, Name));
952   collect(typename FromTypeList::tail());
953 }
954 
955 /// Variadic operator overload.
956 template <unsigned MinCount, unsigned MaxCount>
957 std::unique_ptr<MatcherDescriptor> makeMatcherAutoMarshall(
958     ast_matchers::internal::VariadicOperatorMatcherFunc<MinCount, MaxCount>
959         Func,
960     StringRef MatcherName) {
961   return std::make_unique<VariadicOperatorMatcherDescriptor>(
962       MinCount, MaxCount, Func.Op, MatcherName);
963 }
964 
965 } // namespace internal
966 } // namespace dynamic
967 } // namespace ast_matchers
968 } // namespace clang
969 
#endif // LLVM_CLANG_LIB_ASTMATCHERS_DYNAMIC_MARSHALLERS_H
971