1 #include "UsePointApisCheck.h"
2 
3 #include <algorithm>
4 #include <clang/AST/ASTContext.h>
5 #include <clang/AST/Decl.h>
6 #include <clang/AST/DeclBase.h>
7 #include <clang/AST/DeclTemplate.h>
8 #include <clang/AST/Expr.h>
9 #include <clang/AST/ExprCXX.h>
10 #include <clang/AST/Type.h>
11 #include <clang/ASTMatchers/ASTMatchFinder.h>
12 #include <clang/ASTMatchers/ASTMatchers.h>
13 #include <clang/ASTMatchers/ASTMatchersInternal.h>
14 #include <clang/Basic/Diagnostic.h>
15 #include <clang/Basic/DiagnosticIDs.h>
16 #include <clang/Basic/LLVM.h>
17 #include <clang/Basic/SourceLocation.h>
18 #include <clang/Lex/Lexer.h>
19 #include <climits>
20 #include <llvm/ADT/Twine.h>
21 #include <llvm/Support/Casting.h>
22 #include <string>
23 
24 #include "Utils.h"
25 #include "clang/Basic/OperatorKinds.h"
26 
27 using namespace clang::ast_matchers;
28 
29 namespace clang
30 {
31 namespace tidy
32 {
33 namespace cata
34 {
35 
registerMatchers(MatchFinder * Finder)36 void UsePointApisCheck::registerMatchers( MatchFinder *Finder )
37 {
38     Finder->addMatcher(
39         callExpr(
40             forEachArgumentWithParam(
41                 expr().bind( "xarg" ),
42                 parmVarDecl( hasType( isInteger() ), isXParam() ).bind( "xparam" )
43             ),
44             callee( functionDecl().bind( "callee" ) )
45         ).bind( "call" ),
46         this
47     );
48     Finder->addMatcher(
49         cxxConstructExpr(
50             forEachArgumentWithParam(
51                 expr().bind( "xarg" ),
52                 parmVarDecl(
53                     anyOf( hasType( asString( "int" ) ), hasType( asString( "const int" ) ) ),
54                     isXParam()
55                 ).bind( "xparam" )
56             ),
57             hasDeclaration(
58                 cxxMethodDecl( unless( ofClass( isPointOrCoordPointType() ) ) ).bind( "callee" )
59             )
60         ).bind( "constructorCall" ),
61         this
62     );
63 }
64 
doFunctionsMatch(const FunctionDecl * Callee,const FunctionDecl * OtherCallee,unsigned int NumCoordParams,unsigned int SkipArgs,unsigned int MinArg,bool IsTripoint)65 static bool doFunctionsMatch( const FunctionDecl *Callee, const FunctionDecl *OtherCallee,
66                               unsigned int NumCoordParams, unsigned int SkipArgs,
67                               unsigned int MinArg, bool IsTripoint )
68 {
69     const unsigned int ExpectedNumParams = Callee->getNumParams() - ( NumCoordParams - 1 );
70 
71     if( OtherCallee->getNumParams() != ExpectedNumParams ) {
72         return false;
73     }
74     // Check that arguments match up as expected
75     unsigned int CalleeParamI = 0;
76     unsigned int OtherCalleeParamI = 0;
77 
78     for( ; CalleeParamI < Callee->getNumParams(); ++CalleeParamI, ++OtherCalleeParamI ) {
79         const ParmVarDecl *CalleeParam = Callee->getParamDecl( CalleeParamI );
80         const ParmVarDecl *OtherCalleeParam =
81             OtherCallee->getParamDecl( OtherCalleeParamI );
82 
83         if( CalleeParamI == MinArg - SkipArgs ) {
84             std::string ShortTypeName = IsTripoint ? "tripoint" : "point";
85             std::string ExpectedTypeName = "const struct " + ShortTypeName + " &";
86             if( OtherCalleeParam->getType().getAsString() != ExpectedTypeName ) {
87                 return false;
88             }
89             CalleeParamI += NumCoordParams - 1;
90         } else {
91             // Compare the types as strings because if e.g. the two overloads
92             // are function templates then the tmplate parameters will be
93             // different types.
94             if( CalleeParam->getType().getLocalUnqualifiedType().getAsString() !=
95                 OtherCalleeParam->getType().getLocalUnqualifiedType().getAsString() ) {
96                 return false;
97             }
98         }
99     }
100 
101     return true;
102 }
103 
CheckCall(UsePointApisCheck & Check,const MatchFinder::MatchResult & Result)104 static void CheckCall( UsePointApisCheck &Check, const MatchFinder::MatchResult &Result )
105 {
106     const ParmVarDecl *XParam = Result.Nodes.getNodeAs<ParmVarDecl>( "xparam" );
107     const Expr *XArg = Result.Nodes.getNodeAs<Expr>( "xarg" );
108     const CallExpr *Call = Result.Nodes.getNodeAs<CallExpr>( "call" );
109     const CXXConstructExpr *ConstructorCall =
110         Result.Nodes.getNodeAs<CXXConstructExpr>( "constructorCall" );
111     const FunctionDecl *Callee = Result.Nodes.getNodeAs<FunctionDecl>( "callee" );
112     if( !XParam || !XArg || !( Call || ConstructorCall ) || !Callee ) {
113         return;
114     }
115 
116     const Expr *YArg = nullptr;
117     const Expr *ZArg = nullptr;
118     unsigned int MinArg = UINT_MAX;
119     unsigned int MaxArg = 0;
120 
121     unsigned int NumCallArgs = Call ? Call->getNumArgs() : ConstructorCall->getNumArgs();
122     SourceLocation CallBeginLoc = Call ? Call->getBeginLoc() : ConstructorCall->getBeginLoc();
123     auto GetCallArg = [&]( unsigned int Arg ) {
124         return Call ? Call->getArg( Arg ) : ConstructorCall->getArg( Arg );
125     };
126 
127     // For operator() and operator= calls there is an extra 'this' argument that doesn't
128     // correspond to any parameter, so we need to skip over it.
129     unsigned int SkipArgs = 0;
130     if( Callee->getOverloadedOperator() == OO_Call ||
131         Callee->getOverloadedOperator() == OO_Subscript ||
132         Callee->getOverloadedOperator() == OO_Equal ) {
133         SkipArgs = 1;
134     }
135 
136     if( NumCallArgs - SkipArgs > Callee->getNumParams() ) {
137         Check.diag(
138             CallBeginLoc,
139             "Internal check error: call has more arguments (%0) than function has parameters (%1)"
140         ) << Call->getNumArgs() << Callee->getNumParams();
141         Check.diag( Callee->getLocation(), "called function %0", DiagnosticIDs::Note ) << Callee;
142         return;
143     }
144 
145     NameConvention NameMatcher( XParam->getName() );
146 
147     if( !NameMatcher ) {
148         return;
149     }
150 
151     for( unsigned int i = SkipArgs; i < NumCallArgs; ++i ) {
152         const ParmVarDecl *Param = Callee->getParamDecl( i - SkipArgs );
153         bool Matched = true;
154         switch( NameMatcher.Match( Param->getName() ) ) {
155             case NameConvention::XName:
156                 break;
157             case NameConvention::YName:
158                 YArg = GetCallArg( i );
159                 break;
160             case NameConvention::ZName:
161                 ZArg = GetCallArg( i );
162                 break;
163             default:
164                 Matched = false;
165         }
166 
167         if( Matched ) {
168             MinArg = std::min( MinArg, i );
169             MaxArg = std::max( MaxArg, i );
170         }
171     }
172 
173     if( !YArg ) {
174         return;
175     }
176 
177     const unsigned int NumCoordParams = ZArg ? 3 : 2;
178 
179     if( MaxArg - MinArg != NumCoordParams - 1 ) {
180         // This means that the parameters are not contiguous, which means we
181         // can't be sure we know what's going on.
182         return;
183     }
184 
185     const FunctionDecl *ContainingFunction = getContainingFunction(
186                 Result, Call ? static_cast<const Expr *>( Call ) : ConstructorCall );
187 
188     // Look for another overload of the called function with a point parameter
189     // in the right spot.
190 
191     const FunctionDecl *NewCallee = nullptr;
192     const DeclContext *Context = Callee->getDeclContext();
193     for( const NamedDecl *OtherDecl : Context->lookup( Callee->getDeclName() ) ) {
194         if( const FunctionDecl *OtherCallee = dyn_cast<FunctionDecl>( OtherDecl ) ) {
195             if( OtherCallee == Callee || OtherCallee == ContainingFunction ) {
196                 continue;
197             }
198 
199             if( doFunctionsMatch( Callee, OtherCallee, NumCoordParams, SkipArgs, MinArg,
200                                   !!ZArg ) ) {
201                 NewCallee = OtherCallee;
202                 break;
203             }
204         }
205         if( const FunctionTemplateDecl *OtherTmpl =
206                 dyn_cast<FunctionTemplateDecl>( OtherDecl ) ) {
207             const FunctionTemplateDecl *Tmpl = Callee->getPrimaryTemplate();
208 
209             if( !Tmpl || Tmpl == OtherTmpl ) {
210                 continue;
211             }
212 
213             if( doFunctionsMatch( Tmpl->getTemplatedDecl(), OtherTmpl->getTemplatedDecl(),
214                                   NumCoordParams, SkipArgs, MinArg, !!ZArg ) ) {
215                 NewCallee = OtherTmpl->getTemplatedDecl();
216                 break;
217             }
218         }
219     }
220 
221     if( !NewCallee ) {
222         // No new overload available; no replacement to suggest
223         return;
224     }
225 
226     // Construct replacement text
227     std::string Replacement =
228         ( "point( " + getText( Result, XArg ) + ", " + getText( Result, YArg ) ).str();
229     if( ZArg ) {
230         Replacement = ( "tri" + Replacement + ", " + getText( Result, ZArg ) ).str();
231     }
232     Replacement += " )";
233 
234     // Construct range to be replaced
235     while( isa<CXXDefaultArgExpr>( GetCallArg( MaxArg ) ) ) {
236         --MaxArg;
237         if( MaxArg == UINT_MAX ) {
238             // We underflowed; that means every argument was defaulted.  In
239             // this case, we don't want to change the call at all
240             return;
241         }
242     }
243     SourceRange SourceRangeToReplace( GetCallArg( MinArg )->getBeginLoc(),
244                                       GetCallArg( MaxArg )->getEndLoc() );
245     CharSourceRange CharRangeToReplace = Lexer::makeFileCharRange(
246             CharSourceRange::getTokenRange( SourceRangeToReplace ), *Result.SourceManager,
247             Check.getLangOpts() );
248 
249     std::string message =
250         ZArg ? "Call to %0 could instead call overload using a tripoint parameter."
251         : "Call to %0 could instead call overload using a point parameter.";
252 
253     Check.diag( CallBeginLoc, message )
254             << Callee << FixItHint::CreateReplacement( CharRangeToReplace, Replacement );
255     Check.diag( Callee->getLocation(), "current overload", DiagnosticIDs::Note );
256     Check.diag( NewCallee->getLocation(), "alternate overload", DiagnosticIDs::Note );
257 }
258 
check(const MatchFinder::MatchResult & Result)259 void UsePointApisCheck::check( const MatchFinder::MatchResult &Result )
260 {
261     CheckCall( *this, Result );
262 }
263 
264 } // namespace cata
265 } // namespace tidy
266 } // namespace clang
267