1 //===----------------------------------------------------------------------===//
2 //
3 // Copyright (c) 2012, 2013, 2014, 2015, 2016, 2017 The University of Utah
4 // All rights reserved.
5 //
6 // This file is distributed under the University of Illinois Open Source
7 // License.  See the file COPYING for details.
8 //
9 //===----------------------------------------------------------------------===//
10 
11 #if HAVE_CONFIG_H
12 #  include <config.h>
13 #endif
14 
15 #include "ReduceClassTemplateParameter.h"
16 
17 #include "llvm/ADT/DenseMap.h"
18 #include "llvm/ADT/SmallPtrSet.h"
19 #include "clang/AST/RecursiveASTVisitor.h"
20 #include "clang/AST/ASTContext.h"
21 #include "clang/Basic/SourceManager.h"
22 
23 #include "TransformationManager.h"
24 
25 using namespace clang;
26 
// User-visible summary of what the "reduce-class-template-param" pass does.
// Adjacent literals are concatenated by the compiler into a single string.
static const char *DescriptionMsg =
  "This pass tries to remove one unused parameter from a class template "
  "declaration and also erase the corresponding template argument "
  "from template instantiations/specializations. Note that this pass "
  "does not target those templates with single argument, and skips "
  "variadic templates as well. ";
33 
34 static RegisterTransformation<ReduceClassTemplateParameter>
35          Trans("reduce-class-template-param", DescriptionMsg);
36 
37 class ReduceClassTemplateParameterASTVisitor : public
38   RecursiveASTVisitor<ReduceClassTemplateParameterASTVisitor> {
39 
40 public:
ReduceClassTemplateParameterASTVisitor(ReduceClassTemplateParameter * Instance)41   explicit ReduceClassTemplateParameterASTVisitor(
42              ReduceClassTemplateParameter *Instance)
43     : ConsumerInstance(Instance)
44   { }
45 
46   bool VisitClassTemplateDecl(ClassTemplateDecl *D);
47 
48 private:
49   ReduceClassTemplateParameter *ConsumerInstance;
50 
51 };
52 
53 namespace {
54 
55 typedef llvm::SmallPtrSet<const NamedDecl *, 8> TemplateParameterSet;
56 
57 class TemplateParameterVisitor : public
58   RecursiveASTVisitor<TemplateParameterVisitor> {
59 
60 public:
TemplateParameterVisitor(TemplateParameterSet & Params)61   explicit TemplateParameterVisitor(TemplateParameterSet &Params)
62              : UsedParameters(Params)
63   { }
64 
~TemplateParameterVisitor()65   ~TemplateParameterVisitor() { };
66 
67   bool VisitTemplateTypeParmType(TemplateTypeParmType *Ty);
68 
69 private:
70 
71   TemplateParameterSet &UsedParameters;
72 };
73 
VisitTemplateTypeParmType(TemplateTypeParmType * Ty)74 bool TemplateParameterVisitor::VisitTemplateTypeParmType(
75        TemplateTypeParmType *Ty)
76 {
77   const TemplateTypeParmDecl *D = Ty->getDecl();
78   UsedParameters.insert(D);
79   return true;
80 }
81 
82 class ArgumentDependencyVisitor : public
83   RecursiveASTVisitor<ArgumentDependencyVisitor> {
84 
85 public:
86   typedef llvm::DenseMap<const Type *, unsigned> TypeToVisitsCountSet;
87 
ArgumentDependencyVisitor(TypeToVisitsCountSet & CounterSet)88   explicit ArgumentDependencyVisitor(TypeToVisitsCountSet &CounterSet)
89              : VisitsCountSet(CounterSet)
90   { }
91 
92   bool VisitTemplateTypeParmType(TemplateTypeParmType *Ty);
93 
94 private:
95 
96   TypeToVisitsCountSet &VisitsCountSet;
97 };
98 
VisitTemplateTypeParmType(TemplateTypeParmType * Ty)99 bool ArgumentDependencyVisitor::VisitTemplateTypeParmType(
100        TemplateTypeParmType *Ty)
101 {
102   TypeToVisitsCountSet::iterator I = VisitsCountSet.find(Ty);
103   if (I != VisitsCountSet.end()) {
104     unsigned Count = (*I).second + 1;
105     VisitsCountSet[(*I).first] = Count;
106   }
107   return true;
108 }
109 
110 class ClassTemplateMethodVisitor : public
111   RecursiveASTVisitor<ClassTemplateMethodVisitor> {
112 
113 public:
ClassTemplateMethodVisitor(ReduceClassTemplateParameter * Instance,unsigned Idx)114   ClassTemplateMethodVisitor(ReduceClassTemplateParameter *Instance,
115                              unsigned Idx)
116     : ConsumerInstance(Instance), TheParameterIndex(Idx)
117   { }
118 
119   bool VisitFunctionDecl(FunctionDecl *FD);
120 
121 private:
122   ReduceClassTemplateParameter *ConsumerInstance;
123 
124   unsigned TheParameterIndex;
125 };
126 
VisitFunctionDecl(FunctionDecl * FD)127 bool ClassTemplateMethodVisitor::VisitFunctionDecl(FunctionDecl *FD)
128 {
129   FunctionTemplateDecl *TD = FD->getDescribedFunctionTemplate();
130   for (FunctionDecl::redecl_iterator I = FD->redecls_begin(),
131        E = FD->redecls_end(); I != E; ++I) {
132     unsigned Num = (*I)->getNumTemplateParameterLists();
133     for (unsigned Idx = 0; Idx < Num; ++Idx) {
134       const TemplateParameterList *TPList = (*I)->getTemplateParameterList(Idx);
135       // We don't want to mistakenly rewrite template parameters associated
136       // with the FD if FD is a function template.
137       if (TD && TPList == TD->getTemplateParameters())
138         continue;
139       const NamedDecl *Param = TPList->getParam(TheParameterIndex);
140       SourceRange Range = Param->getSourceRange();
141       ConsumerInstance->removeParameterByRange(Range, TPList,
142                                                TheParameterIndex);
143     }
144   }
145   return true;
146 }
147 
148 }
149 
150 class ReduceClassTemplateParameterRewriteVisitor : public
151   RecursiveASTVisitor<ReduceClassTemplateParameterRewriteVisitor> {
152 
153 public:
ReduceClassTemplateParameterRewriteVisitor(ReduceClassTemplateParameter * Instance)154   explicit ReduceClassTemplateParameterRewriteVisitor(
155              ReduceClassTemplateParameter *Instance)
156     : ConsumerInstance(Instance)
157   { }
158 
159   bool VisitTemplateSpecializationTypeLoc(TemplateSpecializationTypeLoc Loc);
160 
161 private:
162 
163   ReduceClassTemplateParameter *ConsumerInstance;
164 };
165 
166 bool ReduceClassTemplateParameterRewriteVisitor::
VisitTemplateSpecializationTypeLoc(TemplateSpecializationTypeLoc Loc)167        VisitTemplateSpecializationTypeLoc(TemplateSpecializationTypeLoc Loc)
168 {
169   // Invalidation can be introduced by constructor's initialization list, e.g.:
170   // template<typename T1, typename T2> class A { };
171   // class B : public A<int, int> {
172   //   int m;
173   //   B(int x) : m(x) {}
174   // };
175   // In RecursiveASTVisitor.h, TraverseConstructorInitializer will visit the part
176   // of initializing base class's, i.e. through base's default constructor
177   if (Loc.getBeginLoc().isInvalid())
178     return true;
179   const TemplateSpecializationType *Ty =
180     dyn_cast<TemplateSpecializationType>(Loc.getTypePtr());
181   TransAssert(Ty && "Invalid TemplateSpecializationType!");
182 
183   TemplateName TmplName = Ty->getTemplateName();
184   if (!ConsumerInstance->referToTheTemplateDecl(TmplName))
185     return true;
186 
187   unsigned NumArgs = Loc.getNumArgs();
188   // I would put a stronger assert here, i.e.,
189   // " (ConsumerInstance->TheParameterIndex >= NumArgs) &&
190   // ConsumerInstance->hasDefaultArg "
191   // but sometimes ill-formed input could yield incomplete
192   // info, e.g., for two template decls which refer to the same
193   // template def, one decl could have a non-null default arg,
194   // while another decl's default arg field could be null.
195   if (ConsumerInstance->TheParameterIndex >= NumArgs)
196     return true;
197 
198   TransAssert((ConsumerInstance->TheParameterIndex < NumArgs) &&
199               "TheParameterIndex cannot be greater than NumArgs!");
200   TemplateArgumentLoc ArgLoc = Loc.getArgLoc(ConsumerInstance->TheParameterIndex);
201   SourceRange Range = ArgLoc.getSourceRange();
202 
203   if (NumArgs == 1) {
204     ConsumerInstance->TheRewriter.ReplaceText(SourceRange(Loc.getLAngleLoc(),
205                                                           Loc.getRAngleLoc()),
206                                               "<>");
207   }
208   else if ((ConsumerInstance->TheParameterIndex + 1) == NumArgs) {
209     SourceLocation EndLoc = Loc.getRAngleLoc();
210     EndLoc = EndLoc.getLocWithOffset(-1);
211     ConsumerInstance->RewriteHelper->removeTextFromLeftAt(
212                                        Range, ',', EndLoc);
213   }
214   else {
215     ConsumerInstance->RewriteHelper->removeTextUntil(Range, ',');
216   }
217   return true;
218 }
219 
VisitClassTemplateDecl(ClassTemplateDecl * D)220 bool ReduceClassTemplateParameterASTVisitor::VisitClassTemplateDecl(
221        ClassTemplateDecl *D)
222 {
223   if (ConsumerInstance->isInIncludedFile(D))
224     return true;
225 
226   ClassTemplateDecl *CanonicalD = D->getCanonicalDecl();
227   if (ConsumerInstance->VisitedDecls.count(CanonicalD))
228     return true;
229 
230   ConsumerInstance->VisitedDecls.insert(CanonicalD);
231   if (!ConsumerInstance->isValidClassTemplateDecl(D))
232     return true;
233 
234   TemplateParameterSet ParamsSet;
235   TemplateParameterVisitor ParameterVisitor(ParamsSet);
236   CXXRecordDecl *CXXRD = D->getTemplatedDecl();
237   CXXRecordDecl *Def = CXXRD->getDefinition();
238   if (Def)
239     ParameterVisitor.TraverseDecl(Def);
240 
241   // ISSUE: we should also check the parameter usage for partial template
242   //        specializations. For example:
243   //   template<typename T1, typename T2> struct S{};
244   //   template<typename T1, typename T2> struct<T1 *, T2 *> S{...};
245   //   T1 or T2 could be used in "..."
246   // Also, we could have another bad transformation, for example,
247   //   template<bool, typename T> struct S{};
248   //   template<typename T> struct<true, T> S{};
249   // if we remove bool and true, we will have two definitions for S
250   TemplateParameterList *TPList;
251   if (Def) {
252     // make sure we use the params as in ParameterVisitor
253     const ClassTemplateDecl *CT = Def->getDescribedClassTemplate();
254     TransAssert(CT && "NULL DescribedClassTemplate!");
255     TPList = CT->getTemplateParameters();
256   }
257   else {
258     TPList = CanonicalD->getTemplateParameters();
259   }
260 
261   unsigned Index = 0;
262   for (TemplateParameterList::const_iterator I = TPList->begin(),
263        E = TPList->end(); I != E; ++I) {
264     const NamedDecl *ND = (*I);
265     if (ParamsSet.count(ND)) {
266       Index++;
267       continue;
268     }
269 
270     ConsumerInstance->ValidInstanceNum++;
271     if (ConsumerInstance->ValidInstanceNum ==
272         ConsumerInstance->TransformationCounter) {
273       ConsumerInstance->TheClassTemplateDecl = CanonicalD;
274       ConsumerInstance->TheParameterIndex = Index;
275       ConsumerInstance->TheTemplateName = new TemplateName(CanonicalD);
276       ConsumerInstance->setDefaultArgFlag(ND);
277     }
278     Index++;
279   }
280 
281   return true;
282 }
283 
Initialize(ASTContext & context)284 void ReduceClassTemplateParameter::Initialize(ASTContext &context)
285 {
286   Transformation::Initialize(context);
287   CollectionVisitor = new ReduceClassTemplateParameterASTVisitor(this);
288   ArgRewriteVisitor = new ReduceClassTemplateParameterRewriteVisitor(this);
289 }
290 
HandleTranslationUnit(ASTContext & Ctx)291 void ReduceClassTemplateParameter::HandleTranslationUnit(ASTContext &Ctx)
292 {
293   if (TransformationManager::isCLangOpt() ||
294       TransformationManager::isOpenCLLangOpt()) {
295     ValidInstanceNum = 0;
296   }
297   else {
298     CollectionVisitor->TraverseDecl(Ctx.getTranslationUnitDecl());
299   }
300 
301   if (QueryInstanceOnly)
302     return;
303 
304   if (TransformationCounter > ValidInstanceNum) {
305     TransError = TransMaxInstanceError;
306     return;
307   }
308 
309   TransAssert(TheClassTemplateDecl && "NULL TheClassTemplateDecl!");
310   TransAssert(ArgRewriteVisitor && "NULL ArgRewriteVisitor!");
311   Ctx.getDiagnostics().setSuppressAllDiagnostics(false);
312 
313   removeParameterFromDecl();
314   removeParameterFromMethods();
315   removeParameterFromPartialSpecs();
316   ArgRewriteVisitor->TraverseDecl(Ctx.getTranslationUnitDecl());
317 
318   if (Ctx.getDiagnostics().hasErrorOccurred() ||
319       Ctx.getDiagnostics().hasFatalErrorOccurred())
320     TransError = TransInternalError;
321 }
322 
removeParameterByRange(SourceRange Range,const TemplateParameterList * TPList,unsigned Index)323 void ReduceClassTemplateParameter::removeParameterByRange(SourceRange Range,
324                                      const TemplateParameterList *TPList,
325                                      unsigned Index)
326 {
327   unsigned NumParams = TPList->size();
328 
329   // if the parameter is the last one
330   if (NumParams == 1) {
331     TheRewriter.ReplaceText(SourceRange(TPList->getLAngleLoc(),
332                                        TPList->getRAngleLoc()),
333                             "<>");
334   }
335   else if ((Index + 1) == NumParams) {
336     SourceLocation EndLoc = TPList->getRAngleLoc();
337     EndLoc = EndLoc.getLocWithOffset(-1);
338     RewriteHelper->removeTextFromLeftAt(Range, ',', EndLoc);
339   }
340   else {
341     RewriteHelper->removeTextUntil(Range, ',');
342   }
343 }
344 
removeParameterFromDecl()345 void ReduceClassTemplateParameter::removeParameterFromDecl()
346 {
347   unsigned NumParams = TheClassTemplateDecl->getTemplateParameters()->size();
348 
349   TransAssert((NumParams > 1) && "Bad size of TheClassTemplateDecl!");
350   (void)NumParams;
351 
352   for (ClassTemplateDecl::redecl_iterator
353          I = TheClassTemplateDecl->redecls_begin(),
354          E = TheClassTemplateDecl->redecls_end();
355        I != E; ++I) {
356     const TemplateParameterList *TPList = (*I)->getTemplateParameters();
357     const NamedDecl *Param = TPList->getParam(TheParameterIndex);
358     SourceRange Range = Param->getSourceRange();
359     removeParameterByRange(Range, TPList, TheParameterIndex);
360   }
361 }
362 
removeParameterFromMethods()363 void ReduceClassTemplateParameter::removeParameterFromMethods()
364 {
365   CXXRecordDecl *CXXRD = TheClassTemplateDecl->getTemplatedDecl();
366   for (auto I = CXXRD->method_begin(), E = CXXRD->method_end();
367        I != E; ++I) {
368     ClassTemplateMethodVisitor V(this, TheParameterIndex);
369     V.TraverseDecl(*I);
370   }
371 }
372 
removeOneParameterByArgExpression(const ClassTemplatePartialSpecializationDecl * PartialD,const TemplateArgument & Arg)373 void ReduceClassTemplateParameter::removeOneParameterByArgExpression(
374        const ClassTemplatePartialSpecializationDecl *PartialD,
375        const TemplateArgument &Arg)
376 {
377   TransAssert((Arg.getKind() == TemplateArgument::Expression) &&
378               "Arg is not TemplateArgument::Expression!");
379 
380   const Expr *E = Arg.getAsExpr();
381   TransAssert(E && "Bad Expression!");
382   const DeclRefExpr *DRE = dyn_cast<DeclRefExpr>(E->IgnoreParenCasts());
383   TransAssert(DRE && "Bad DeclRefExpr!");
384   const NonTypeTemplateParmDecl *ParmD =
385     dyn_cast<NonTypeTemplateParmDecl>(DRE->getDecl());
386   TransAssert(ParmD && "Invalid NonTypeTemplateParmDecl!");
387 
388   const TemplateParameterList *TPList = PartialD->getTemplateParameters();
389   unsigned Idx = 0;
390   for (TemplateParameterList::const_iterator I = TPList->begin(),
391        E = TPList->end(); I != E; ++I) {
392     if ((*I) == ParmD)
393       break;
394     Idx++;
395   }
396 
397   unsigned NumParams = TPList->size();
398   TransAssert((Idx < NumParams) && "Cannot find valid TemplateParameter!");
399   (void)NumParams;
400   SourceRange Range = ParmD->getSourceRange();
401   removeParameterByRange(Range, TPList, Idx);
402 }
403 
removeOneParameterByArgType(const ClassTemplatePartialSpecializationDecl * PartialD,const TemplateArgument & Arg)404 void ReduceClassTemplateParameter::removeOneParameterByArgType(
405        const ClassTemplatePartialSpecializationDecl *PartialD,
406        const TemplateArgument &Arg)
407 {
408   TransAssert((Arg.getKind() == TemplateArgument::Type) &&
409               "Arg is not TemplateArgument::Type!");
410   llvm::DenseMap<const Type *, unsigned> TypeToVisitsCount;
411   llvm::DenseMap<const Type *, const NamedDecl *> TypeToNamedDecl;
412   llvm::DenseMap<const Type *, unsigned> TypeToIndex;
413 
414   // retrieve all TemplateTypeParmType
415   const TemplateParameterList *TPList = PartialD->getTemplateParameters();
416   unsigned Idx = 0;
417   for (TemplateParameterList::const_iterator I = TPList->begin(),
418        E = TPList->end(); I != E; ++I) {
419     const NamedDecl *ND = (*I);
420     const TemplateTypeParmDecl *TypeD = dyn_cast<TemplateTypeParmDecl>(ND);
421     if (!TypeD) {
422       Idx++;
423       continue;
424     }
425     const Type *ParmTy = TypeD->getTypeForDecl();
426     TypeToVisitsCount[ParmTy] = 0;
427     TypeToNamedDecl[ParmTy] = ND;
428     TypeToIndex[ParmTy] = Idx;
429   }
430 
431   QualType QTy = Arg.getAsType();
432   ArgumentDependencyVisitor V(TypeToVisitsCount);
433   // collect TemplateTypeParmType being used by Arg
434   V.TraverseType(QTy);
435 
436   llvm::DenseMap<const Type *, unsigned> DependentTypeToVisitsCount;
437   for (llvm::DenseMap<const Type *, unsigned>::iterator
438          I = TypeToVisitsCount.begin(), E = TypeToVisitsCount.end();
439        I != E; ++I) {
440     if ((*I).second > 0)
441       DependentTypeToVisitsCount[(*I).first] = 1;
442   }
443 
444   // check if the used TemplateTypeParmType[s] have dependencies
445   // on other Args. If yes, we cannot remove it from the parameter list.
446   // For example:
447   //   template <typename T>
448   //   struct S <T*, T&> {};
449   // removing either of the arguments needs to keep the template
450   // parameter
451   ArgumentDependencyVisitor AccumV(DependentTypeToVisitsCount);
452 
453   const ASTTemplateArgumentListInfo *ArgList =
454     PartialD->getTemplateArgsAsWritten();
455 
456   const TemplateArgumentLoc *ArgLocs = ArgList->getTemplateArgs();
457   unsigned NumArgs = ArgList->NumTemplateArgs;
458   TransAssert((TheParameterIndex < NumArgs) &&
459                "Bad NumArgs from partial template decl!");
460   for (unsigned I = 0; I < NumArgs; ++I) {
461     if (I == TheParameterIndex)
462       continue;
463 
464     const TemplateArgumentLoc ArgLoc = ArgLocs[I];
465     TemplateArgument OtherArg = ArgLoc.getArgument();
466     if (OtherArg.isInstantiationDependent() &&
467         (OtherArg.getKind() == TemplateArgument::Type)) {
468       QualType QTy = OtherArg.getAsType();
469       AccumV.TraverseType(QTy);
470     }
471   }
472 
473   for (llvm::DenseMap<const Type *, unsigned>::iterator
474          I = DependentTypeToVisitsCount.begin(),
475          E = DependentTypeToVisitsCount.end();
476        I != E; ++I) {
477     if ((*I).second != 1)
478       continue;
479 
480     const NamedDecl *Param = TypeToNamedDecl[(*I).first];
481     TransAssert(Param && "NULL Parameter!");
482     SourceRange Range = Param->getSourceRange();
483     removeParameterByRange(Range, TPList, TypeToIndex[(*I).first]);
484   }
485 }
486 
removeOneParameterByArgTemplate(const ClassTemplatePartialSpecializationDecl * PartialD,const TemplateArgument & Arg)487 void ReduceClassTemplateParameter::removeOneParameterByArgTemplate(
488        const ClassTemplatePartialSpecializationDecl *PartialD,
489        const TemplateArgument &Arg)
490 {
491   TransAssert((Arg.getKind() == TemplateArgument::Template) &&
492               "Arg is not TemplateArgument::Template!");
493   TemplateName TmplName = Arg.getAsTemplate();
494   TransAssert((TmplName.getKind() == TemplateName::Template) &&
495               "Invalid TemplateName Kind!");
496   const TemplateDecl *TmplD = TmplName.getAsTemplateDecl();
497 
498   const TemplateParameterList *TPList = PartialD->getTemplateParameters();
499   unsigned Idx = 0;
500   for (TemplateParameterList::const_iterator I = TPList->begin(),
501        E = TPList->end(); I != E; ++I) {
502     if ((*I) == TmplD)
503       break;
504     Idx++;
505   }
506 
507   unsigned NumParams = TPList->size();
508   TransAssert((Idx < NumParams) && "Cannot find valid TemplateParameter!");
509   (void)NumParams;
510   SourceRange Range = TmplD->getSourceRange();
511   removeParameterByRange(Range, TPList, Idx);
512 
513   return;
514 }
515 
removeOneParameterFromPartialDecl(const ClassTemplatePartialSpecializationDecl * PartialD,const TemplateArgument & Arg)516 void ReduceClassTemplateParameter::removeOneParameterFromPartialDecl(
517        const ClassTemplatePartialSpecializationDecl *PartialD,
518        const TemplateArgument &Arg)
519 {
520   if (!Arg.isInstantiationDependent())
521     return;
522 
523   TemplateArgument::ArgKind K = Arg.getKind();
524   switch (K) {
525   case TemplateArgument::Expression:
526     removeOneParameterByArgExpression(PartialD, Arg);
527     return;
528 
529   case TemplateArgument::Template:
530     removeOneParameterByArgTemplate(PartialD, Arg);
531     return;
532 
533   case TemplateArgument::Type:
534     removeOneParameterByArgType(PartialD, Arg);
535     return;
536 
537   default:
538     TransAssert(0 && "Uncatched ArgKind!");
539   }
540   TransAssert(0 && "Unreachable code!");
541 }
542 
getNamedDecl(const TemplateArgument & Arg)543 const NamedDecl *ReduceClassTemplateParameter::getNamedDecl(
544         const TemplateArgument &Arg)
545 {
546   if (!Arg.isInstantiationDependent())
547     return NULL;
548 
549   TemplateArgument::ArgKind K = Arg.getKind();
550   switch (K) {
551   case TemplateArgument::Expression: {
552     const Expr *E = Arg.getAsExpr();
553     if (const DeclRefExpr *DRE = dyn_cast<DeclRefExpr>(E)) {
554       return dyn_cast<NonTypeTemplateParmDecl>(DRE->getDecl());
555     }
556     else {
557       return NULL;
558     }
559   }
560 
561   case TemplateArgument::Template: {
562     TemplateName TmplName = Arg.getAsTemplate();
563     TransAssert((TmplName.getKind() == TemplateName::Template) &&
564                 "Invalid TemplateName Kind!");
565     return TmplName.getAsTemplateDecl();
566   }
567 
568   case TemplateArgument::Type: {
569     const Type *Ty = Arg.getAsType().getTypePtr();
570     if (const TemplateTypeParmType *TmplTy =
571         dyn_cast<TemplateTypeParmType>(Ty)) {
572       return TmplTy->getDecl();
573     }
574     else {
575       return NULL;
576     }
577   }
578 
579   default:
580     return NULL;
581   }
582 
583   TransAssert(0 && "Unreachable code!");
584   return NULL;
585 }
586 
referToAParameter(const ClassTemplatePartialSpecializationDecl * PartialD,const TemplateArgument & Arg)587 bool ReduceClassTemplateParameter::referToAParameter(
588        const ClassTemplatePartialSpecializationDecl *PartialD,
589        const TemplateArgument &Arg)
590 {
591   const NamedDecl *ArgND = getNamedDecl(Arg);
592   if (!ArgND)
593     return false;
594 
595   const TemplateParameterList *TPList = PartialD->getTemplateParameters();
596   for (TemplateParameterList::const_iterator PI = TPList->begin(),
597        PE = TPList->end(); PI != PE; ++PI) {
598     if (ArgND != (*PI))
599       return false;
600   }
601   return true;
602 }
603 
isValidForReduction(const ClassTemplatePartialSpecializationDecl * PartialD)604 bool ReduceClassTemplateParameter::isValidForReduction(
605        const ClassTemplatePartialSpecializationDecl *PartialD)
606 {
607   const ASTTemplateArgumentListInfo *ArgList =
608     PartialD->getTemplateArgsAsWritten();
609 
610   unsigned NumArgsAsWritten = ArgList->NumTemplateArgs;
611   unsigned NumArgs = PartialD->getTemplateInstantiationArgs().size();
612 
613   if ((NumArgsAsWritten > 0) &&
614       (TheParameterIndex >= NumArgsAsWritten) &&
615       hasDefaultArg &&
616       ((NumArgsAsWritten + 1) == NumArgs))  {
617 
618     return true;
619   }
620 
621   if (NumArgsAsWritten != NumArgs)
622     return false;
623 
624   const TemplateArgumentLoc *ArgLocs = ArgList->getTemplateArgs();
625   for (unsigned AI = 0; AI < NumArgsAsWritten; ++AI) {
626     if (AI == TheParameterIndex)
627       continue;
628     const TemplateArgumentLoc ArgLoc = ArgLocs[AI];
629     TemplateArgument Arg = ArgLoc.getArgument();
630     if (!referToAParameter(PartialD, Arg))
631       return false;
632   }
633 
634   return true;
635 }
636 
reducePartialSpec(const ClassTemplatePartialSpecializationDecl * PartialD)637 bool ReduceClassTemplateParameter::reducePartialSpec(
638        const ClassTemplatePartialSpecializationDecl *PartialD)
639 {
640   const CXXRecordDecl *CXXRD = TheClassTemplateDecl->getTemplatedDecl();
641   // it CXXRD has definition, skip it to avoid duplication
642   if (CXXRD->hasDefinition())
643     return false;
644 
645   if (!isValidForReduction(PartialD))
646     return false;
647 
648   const ASTTemplateArgumentListInfo *ArgList =
649     PartialD->getTemplateArgsAsWritten();
650   const TemplateArgumentLoc *ArgLocs = ArgList->getTemplateArgs();
651   unsigned NumArgsAsWritten = ArgList->NumTemplateArgs;
652 
653   const TemplateArgumentLoc FirstArgLoc = ArgLocs[0];
654   SourceRange FirstRange = FirstArgLoc.getSourceRange();
655   SourceLocation StartLoc = FirstRange.getBegin();
656 
657   const TemplateArgumentLoc LastArgLoc = ArgLocs[NumArgsAsWritten - 1];
658   SourceRange LastRange = LastArgLoc.getSourceRange();
659   SourceLocation EndLoc =
660     RewriteHelper->getEndLocationUntil(LastRange, '>');
661 
662   RewriteHelper->removeTextFromLeftAt(SourceRange(StartLoc, EndLoc), '<', EndLoc);
663   return true;
664 }
665 
666 // ISSUE: The transformation is known to go wrong in the following case:
667 // template<typename T1, typename T2> struct S;
668 // template<typename T1, typename T2> struct S<T2, T1>;
removeParameterFromPartialSpecs()669 void ReduceClassTemplateParameter::removeParameterFromPartialSpecs()
670 {
671   SmallVector<ClassTemplatePartialSpecializationDecl *, 10> PartialDecls;
672   TheClassTemplateDecl->getPartialSpecializations(PartialDecls);
673 
674   for (SmallVector<ClassTemplatePartialSpecializationDecl *, 10>::iterator
675          I = PartialDecls.begin(), E = PartialDecls.end(); I != E; ++I) {
676     const ClassTemplatePartialSpecializationDecl *PartialD = (*I);
677 
678     const ASTTemplateArgumentListInfo *ArgList =
679       PartialD->getTemplateArgsAsWritten();
680     const TemplateArgumentLoc *ArgLocs = ArgList->getTemplateArgs();
681     unsigned NumArgs = ArgList->NumTemplateArgs;
682 
683     if (!ArgLocs)
684       continue;
685 
686     // handle a special case where we could reduce a partial specialization
687     // to a class template definition, e.g.:
688     //   template<typename T1, typename T2> struct A;
689     //   template<typename T1> struct A<T1, int> { };
690     // ==>
691     //   template<typename T1> struct A;
692     //   template<typename T1> struct A { };
693     if (reducePartialSpec(PartialD))
694       continue;
695 
696     if ((TheParameterIndex >= NumArgs) && hasDefaultArg)
697       return;
698 
699     TransAssert((TheParameterIndex < NumArgs) &&
700                  "Bad NumArgs from partial template decl!");
701     TemplateArgumentLoc ArgLoc = ArgLocs[TheParameterIndex];
702 
703     TemplateArgument Arg = ArgLoc.getArgument();
704     removeOneParameterFromPartialDecl(PartialD, Arg);
705 
706     SourceRange Range = ArgLoc.getSourceRange();
707 
708     if (NumArgs == 1) {
709       SourceLocation StartLoc = Range.getBegin();
710       SourceLocation EndLoc =
711         RewriteHelper->getEndLocationUntil(Range, '>');
712       EndLoc = EndLoc.getLocWithOffset(-1);
713       TheRewriter.RemoveText(SourceRange(StartLoc, EndLoc));
714     }
715     else if ((TheParameterIndex + 1) == NumArgs) {
716       // Seems there is no getRAngleLoc() utility for
717       // template arguments from a partial specialization
718       SourceLocation EndLoc =
719         RewriteHelper->getEndLocationUntil(Range, '>');
720       EndLoc = EndLoc.getLocWithOffset(-1);
721       RewriteHelper->removeTextFromLeftAt(Range, ',', EndLoc);
722     }
723     else {
724       RewriteHelper->removeTextUntil(Range, ',');
725     }
726   }
727 }
728 
isValidClassTemplateDecl(const ClassTemplateDecl * D)729 bool ReduceClassTemplateParameter::isValidClassTemplateDecl(
730                                      const ClassTemplateDecl *D)
731 {
732   const TemplateParameterList *TPList = D->getTemplateParameters();
733   if (TPList->size() <= 1)
734     return false;
735 
736   // FIXME: need to handle parameter pack later
737   for (TemplateParameterList::const_iterator I = TPList->begin(),
738        E = TPList->end(); I != E; ++I) {
739     if (isParameterPack(*I))
740       return false;
741   }
742   return true;
743 }
744 
setDefaultArgFlag(const NamedDecl * ND)745 void ReduceClassTemplateParameter::setDefaultArgFlag(const NamedDecl *ND)
746 {
747   if (const NonTypeTemplateParmDecl *D =
748       dyn_cast<NonTypeTemplateParmDecl>(ND)) {
749     hasDefaultArg = D->hasDefaultArgument();
750   }
751   else if (const TemplateTypeParmDecl *D =
752              dyn_cast<TemplateTypeParmDecl>(ND)) {
753     hasDefaultArg = D->hasDefaultArgument();
754   }
755   else if (const TemplateTemplateParmDecl *D =
756              dyn_cast<TemplateTemplateParmDecl>(ND)) {
757     hasDefaultArg = D->hasDefaultArgument();
758   }
759   else {
760     TransAssert(0 && "Unknown template parameter type!");
761   }
762 }
763 
referToTheTemplateDecl(TemplateName TmplName)764 bool ReduceClassTemplateParameter::referToTheTemplateDecl(
765                                      TemplateName TmplName)
766 {
767   return Context->hasSameTemplateName(*TheTemplateName, TmplName);
768 }
769 
~ReduceClassTemplateParameter()770 ReduceClassTemplateParameter::~ReduceClassTemplateParameter()
771 {
772   delete TheTemplateName;
773   delete CollectionVisitor;
774   delete ArgRewriteVisitor;
775 }
776 
777