1 //===----------------------------------------------------------------------===//
2 //
3 // Copyright (c) 2012, 2013, 2014, 2015, 2017, 2019 The University of Utah
4 // All rights reserved.
5 //
6 // This file is distributed under the University of Illinois Open Source
7 // License.  See the file COPYING for details.
8 //
9 //===----------------------------------------------------------------------===//
10 
11 #if HAVE_CONFIG_H
12 #  include <config.h>
13 #endif
14 
15 #include "SimplifyDependentTypedef.h"
16 
17 #include "clang/AST/RecursiveASTVisitor.h"
18 #include "clang/AST/ASTContext.h"
19 #include "clang/Basic/SourceManager.h"
20 
21 #include "TransformationManager.h"
22 
23 using namespace clang;
24 
// Description text for this pass; it is passed to RegisterTransformation
// below together with the pass name. The example shows a complex dependent
// typedef underlying type being replaced by a template type parameter.
static const char *DescriptionMsg =
"Sometimes the underlying type of a typedef declaration \
is a complex dependent type which prevents further reduction. \
This pass tries to replace this complex dependent type with a \
simple one. For example, from \n\
  template<typename T> class { \n\
    typedef typename X< typename Y<T> >::type type; \n\
  }; \n\
to \n\
  template<typename T> class { \n\
    typedef T type; \n\
  };\n";

// Static registration of SimplifyDependentTypedef under the name
// "simplify-dependent-typedef", with the description above.
static RegisterTransformation<SimplifyDependentTypedef>
         Trans("simplify-dependent-typedef", DescriptionMsg);
40 
41 class DependentTypedefCollectionVisitor : public
42   RecursiveASTVisitor<DependentTypedefCollectionVisitor> {
43 
44 public:
DependentTypedefCollectionVisitor(SimplifyDependentTypedef * Instance)45   explicit DependentTypedefCollectionVisitor(SimplifyDependentTypedef *Instance)
46     : ConsumerInstance(Instance)
47   { }
48 
49   bool VisitTypedefDecl(TypedefDecl *D);
50 
51 private:
52   SimplifyDependentTypedef *ConsumerInstance;
53 
54 };
55 
56 class DependentTypedefTemplateTypeParmTypeVisitor : public
57   RecursiveASTVisitor<DependentTypedefTemplateTypeParmTypeVisitor> {
58 
59   typedef llvm::SmallPtrSet<const clang::Type *, 5> TemplateTypeParmTypeSet;
60 
61 public:
DependentTypedefTemplateTypeParmTypeVisitor(SimplifyDependentTypedef * Instance)62   explicit DependentTypedefTemplateTypeParmTypeVisitor(
63              SimplifyDependentTypedef *Instance)
64     : IsValidType(false)
65   { }
66 
67   bool VisitTemplateTypeParmType(TemplateTypeParmType *Ty);
68 
setTypeSet(TemplateTypeParmTypeSet * Set)69   void setTypeSet(TemplateTypeParmTypeSet *Set) {
70     TypeSet = Set;
71   }
72 
setValidType(bool Valid)73   void setValidType(bool Valid) {
74     IsValidType = Valid;
75   }
76 
isValidType(void)77   bool isValidType(void) {
78     return IsValidType;
79   }
80 
81 private:
82 
83   TemplateTypeParmTypeSet *TypeSet;
84 
85   bool IsValidType;
86 };
87 
VisitTypedefDecl(TypedefDecl * D)88 bool DependentTypedefCollectionVisitor::VisitTypedefDecl(TypedefDecl *D)
89 {
90   ConsumerInstance->handleOneTypedefDecl(D);
91   return true;
92 }
93 
VisitTemplateTypeParmType(TemplateTypeParmType * Ty)94 bool DependentTypedefTemplateTypeParmTypeVisitor::VisitTemplateTypeParmType(
95        TemplateTypeParmType *Ty)
96 {
97   const Type *CanonicalTy =
98       Ty->getCanonicalTypeInternal().getTypePtr();
99   if (TypeSet->count(CanonicalTy)) {
100     IsValidType = true;
101     return false;
102   }
103   return true;
104 }
105 
Initialize(ASTContext & context)106 void SimplifyDependentTypedef::Initialize(ASTContext &context)
107 {
108   Transformation::Initialize(context);
109   CollectionVisitor = new DependentTypedefCollectionVisitor(this);
110   TemplateTypeParmTypeVisitor =
111     new DependentTypedefTemplateTypeParmTypeVisitor(this);
112 }
113 
HandleTranslationUnit(ASTContext & Ctx)114 void SimplifyDependentTypedef::HandleTranslationUnit(ASTContext &Ctx)
115 {
116   if (TransformationManager::isCLangOpt() ||
117       TransformationManager::isOpenCLLangOpt()) {
118     ValidInstanceNum = 0;
119   }
120 
121   CollectionVisitor->TraverseDecl(Ctx.getTranslationUnitDecl());
122 
123   if (QueryInstanceOnly)
124     return;
125 
126   if (TransformationCounter > ValidInstanceNum) {
127     TransError = TransMaxInstanceError;
128     return;
129   }
130 
131   Ctx.getDiagnostics().setSuppressAllDiagnostics(false);
132   TransAssert(TheTypedefDecl && "NULL TheTypedefDecl!");
133   TransAssert(FirstTmplTypeParmD && "NULL FirstTmplTypeParmD!");
134   rewriteTypedefDecl();
135 
136   if (Ctx.getDiagnostics().hasErrorOccurred() ||
137       Ctx.getDiagnostics().hasFatalErrorOccurred())
138     TransError = TransInternalError;
139 }
140 
rewriteTypedefDecl(void)141 void SimplifyDependentTypedef::rewriteTypedefDecl(void)
142 {
143   SourceLocation LocStart = TheTypedefDecl->getBeginLoc();
144 
145   // skip "typedef "
146   LocStart = LocStart.getLocWithOffset(8);
147   SourceLocation LocEnd = TheTypedefDecl->getLocation();
148   LocEnd = LocEnd.getLocWithOffset(-1);
149 
150   std::string ParmName = FirstTmplTypeParmD->getNameAsString();
151   TransAssert(!ParmName.empty() && "Invalid TypeParmType Name!");
152   // make an explicit blank after the type name in case we
153   // have typedef XXX<T>type;
154   TheRewriter.ReplaceText(SourceRange(LocStart, LocEnd), ParmName+" ");
155 }
156 
handleOneTypedefDecl(const TypedefDecl * D)157 void SimplifyDependentTypedef::handleOneTypedefDecl(const TypedefDecl *D)
158 {
159   if (isInIncludedFile(D))
160     return;
161 
162   const TypedefDecl *CanonicalD = dyn_cast<TypedefDecl>(D->getCanonicalDecl());
163   TransAssert(CanonicalD && "Bad TypedefDecl!");
164   if (VisitedTypedefDecls.count(CanonicalD))
165     return;
166   VisitedTypedefDecls.insert(CanonicalD);
167 
168   const CXXRecordDecl *CXXRD = dyn_cast<CXXRecordDecl>(D->getDeclContext());
169   if (!CXXRD)
170     return;
171 
172   const ClassTemplateDecl *TmplD = CXXRD->getDescribedClassTemplate();
173   if (!TmplD)
174     return;
175 
176   TemplateParameterList *TmplParmList = TmplD->getTemplateParameters();
177   if (TmplParmList->size() == 0)
178     return;
179 
180   TemplateTypeParmTypeSet TypeSet;
181   const TemplateTypeParmDecl *FirstParmD = NULL;
182   for (TemplateParameterList::iterator I = TmplParmList->begin(),
183        E = TmplParmList->end(); I != E; ++I) {
184     if (const TemplateTypeParmDecl *TmplTypeParmD =
185         dyn_cast<TemplateTypeParmDecl>(*I)) {
186       if (!FirstParmD && !TmplTypeParmD->getNameAsString().empty())
187         FirstParmD = TmplTypeParmD;
188       const TemplateTypeParmType *TmplParmTy =
189         dyn_cast<TemplateTypeParmType>(TmplTypeParmD->getTypeForDecl());
190       TransAssert(TmplParmTy && "Bad TemplateTypeParmType!");
191       TypeSet.insert(TmplParmTy->getCanonicalTypeInternal().getTypePtr());
192     }
193   }
194 
195   if (!FirstParmD)
196     return;
197 
198   QualType QT = CanonicalD->getUnderlyingType();
199   const Type *Ty = QT.getTypePtr();
200   Type::TypeClass TC = Ty->getTypeClass();
201   if ((TC != Type::DependentName) &&
202       (TC != Type::DependentTemplateSpecialization) &&
203       (TC != Type::TemplateSpecialization) &&
204       (TC != Type::Elaborated))
205     return;
206 
207   TemplateTypeParmTypeVisitor->setTypeSet(&TypeSet);
208   TemplateTypeParmTypeVisitor->setValidType(false);
209   TemplateTypeParmTypeVisitor->TraverseType(QT);
210 
211   if (!TemplateTypeParmTypeVisitor->isValidType())
212     return;
213 
214   ValidInstanceNum++;
215   if (ValidInstanceNum != TransformationCounter)
216     return;
217 
218   FirstTmplTypeParmD = FirstParmD;
219   TheTypedefDecl = CanonicalD;
220 }
221 
~SimplifyDependentTypedef(void)222 SimplifyDependentTypedef::~SimplifyDependentTypedef(void)
223 {
224   delete CollectionVisitor;
225   delete TemplateTypeParmTypeVisitor;
226 }
227 
228