1 //===--- TypeMismatchCheck.cpp - clang-tidy--------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "TypeMismatchCheck.h"
10 #include "clang/Lex/Lexer.h"
11 #include "clang/Tooling/FixIt.h"
12 #include <map>
13 #include <unordered_set>
14 
15 using namespace clang::ast_matchers;
16 
17 namespace clang {
18 namespace tidy {
19 namespace mpi {
20 
21 /// Check if a BuiltinType::Kind matches the MPI datatype.
22 ///
23 /// \param MultiMap datatype group
24 /// \param Kind buffer type kind
25 /// \param MPIDatatype name of the MPI datatype
26 ///
27 /// \returns true if the pair matches
28 static bool
isMPITypeMatching(const std::multimap<BuiltinType::Kind,std::string> & MultiMap,const BuiltinType::Kind Kind,const std::string & MPIDatatype)29 isMPITypeMatching(const std::multimap<BuiltinType::Kind, std::string> &MultiMap,
30                   const BuiltinType::Kind Kind,
31                   const std::string &MPIDatatype) {
32   auto ItPair = MultiMap.equal_range(Kind);
33   while (ItPair.first != ItPair.second) {
34     if (ItPair.first->second == MPIDatatype)
35       return true;
36     ++ItPair.first;
37   }
38   return false;
39 }
40 
/// Check if the MPI datatype is one of the standard datatypes this check
/// knows how to compare against a buffer type.
///
/// \param MPIDatatype name of the MPI datatype
///
/// \returns true if \p MPIDatatype is contained in the table below
static bool isStandardMPIDatatype(const std::string &MPIDatatype) {
  // The table is only ever read, so it is declared const; it is built once on
  // first use.
  static const std::unordered_set<std::string> AllTypes = {
      "MPI_C_BOOL",
      "MPI_CHAR",
      "MPI_SIGNED_CHAR",
      "MPI_UNSIGNED_CHAR",
      "MPI_WCHAR",
      "MPI_INT",
      "MPI_LONG",
      "MPI_SHORT",
      "MPI_LONG_LONG",
      "MPI_LONG_LONG_INT",
      "MPI_UNSIGNED",
      "MPI_UNSIGNED_SHORT",
      "MPI_UNSIGNED_LONG",
      "MPI_UNSIGNED_LONG_LONG",
      "MPI_FLOAT",
      "MPI_DOUBLE",
      "MPI_LONG_DOUBLE",
      "MPI_C_COMPLEX",
      "MPI_C_FLOAT_COMPLEX",
      "MPI_C_DOUBLE_COMPLEX",
      "MPI_C_LONG_DOUBLE_COMPLEX",
      "MPI_INT8_T",
      "MPI_INT16_T",
      "MPI_INT32_T",
      "MPI_INT64_T",
      "MPI_UINT8_T",
      "MPI_UINT16_T",
      "MPI_UINT32_T",
      "MPI_UINT64_T",
      "MPI_CXX_BOOL",
      "MPI_CXX_FLOAT_COMPLEX",
      "MPI_CXX_DOUBLE_COMPLEX",
      "MPI_CXX_LONG_DOUBLE_COMPLEX"};

  return AllTypes.count(MPIDatatype) != 0;
}
84 
85 /// Check if a BuiltinType matches the MPI datatype.
86 ///
87 /// \param Builtin the builtin type
88 /// \param BufferTypeName buffer type name, gets assigned
89 /// \param MPIDatatype name of the MPI datatype
90 /// \param LO language options
91 ///
92 /// \returns true if the type matches
isBuiltinTypeMatching(const BuiltinType * Builtin,std::string & BufferTypeName,const std::string & MPIDatatype,const LangOptions & LO)93 static bool isBuiltinTypeMatching(const BuiltinType *Builtin,
94                                   std::string &BufferTypeName,
95                                   const std::string &MPIDatatype,
96                                   const LangOptions &LO) {
97   static std::multimap<BuiltinType::Kind, std::string> BuiltinMatches = {
98       // On some systems like PPC or ARM, 'char' is unsigned by default which is
99       // why distinct signedness for the buffer and MPI type is tolerated.
100       {BuiltinType::SChar, "MPI_CHAR"},
101       {BuiltinType::SChar, "MPI_SIGNED_CHAR"},
102       {BuiltinType::SChar, "MPI_UNSIGNED_CHAR"},
103       {BuiltinType::Char_S, "MPI_CHAR"},
104       {BuiltinType::Char_S, "MPI_SIGNED_CHAR"},
105       {BuiltinType::Char_S, "MPI_UNSIGNED_CHAR"},
106       {BuiltinType::UChar, "MPI_CHAR"},
107       {BuiltinType::UChar, "MPI_SIGNED_CHAR"},
108       {BuiltinType::UChar, "MPI_UNSIGNED_CHAR"},
109       {BuiltinType::Char_U, "MPI_CHAR"},
110       {BuiltinType::Char_U, "MPI_SIGNED_CHAR"},
111       {BuiltinType::Char_U, "MPI_UNSIGNED_CHAR"},
112       {BuiltinType::WChar_S, "MPI_WCHAR"},
113       {BuiltinType::WChar_U, "MPI_WCHAR"},
114       {BuiltinType::Bool, "MPI_C_BOOL"},
115       {BuiltinType::Bool, "MPI_CXX_BOOL"},
116       {BuiltinType::Short, "MPI_SHORT"},
117       {BuiltinType::Int, "MPI_INT"},
118       {BuiltinType::Long, "MPI_LONG"},
119       {BuiltinType::LongLong, "MPI_LONG_LONG"},
120       {BuiltinType::LongLong, "MPI_LONG_LONG_INT"},
121       {BuiltinType::UShort, "MPI_UNSIGNED_SHORT"},
122       {BuiltinType::UInt, "MPI_UNSIGNED"},
123       {BuiltinType::ULong, "MPI_UNSIGNED_LONG"},
124       {BuiltinType::ULongLong, "MPI_UNSIGNED_LONG_LONG"},
125       {BuiltinType::Float, "MPI_FLOAT"},
126       {BuiltinType::Double, "MPI_DOUBLE"},
127       {BuiltinType::LongDouble, "MPI_LONG_DOUBLE"}};
128 
129   if (!isMPITypeMatching(BuiltinMatches, Builtin->getKind(), MPIDatatype)) {
130     BufferTypeName = std::string(Builtin->getName(LO));
131     return false;
132   }
133 
134   return true;
135 }
136 
137 /// Check if a complex float/double/long double buffer type matches
138 /// the MPI datatype.
139 ///
140 /// \param Complex buffer type
141 /// \param BufferTypeName buffer type name, gets assigned
142 /// \param MPIDatatype name of the MPI datatype
143 /// \param LO language options
144 ///
145 /// \returns true if the type matches or the buffer type is unknown
isCComplexTypeMatching(const ComplexType * const Complex,std::string & BufferTypeName,const std::string & MPIDatatype,const LangOptions & LO)146 static bool isCComplexTypeMatching(const ComplexType *const Complex,
147                                    std::string &BufferTypeName,
148                                    const std::string &MPIDatatype,
149                                    const LangOptions &LO) {
150   static std::multimap<BuiltinType::Kind, std::string> ComplexCMatches = {
151       {BuiltinType::Float, "MPI_C_COMPLEX"},
152       {BuiltinType::Float, "MPI_C_FLOAT_COMPLEX"},
153       {BuiltinType::Double, "MPI_C_DOUBLE_COMPLEX"},
154       {BuiltinType::LongDouble, "MPI_C_LONG_DOUBLE_COMPLEX"}};
155 
156   const auto *Builtin =
157       Complex->getElementType().getTypePtr()->getAs<BuiltinType>();
158 
159   if (Builtin &&
160       !isMPITypeMatching(ComplexCMatches, Builtin->getKind(), MPIDatatype)) {
161     BufferTypeName = (llvm::Twine(Builtin->getName(LO)) + " _Complex").str();
162     return false;
163   }
164   return true;
165 }
166 
167 /// Check if a complex<float/double/long double> templated buffer type matches
168 /// the MPI datatype.
169 ///
170 /// \param Template buffer type
171 /// \param BufferTypeName buffer type name, gets assigned
172 /// \param MPIDatatype name of the MPI datatype
173 /// \param LO language options
174 ///
175 /// \returns true if the type matches or the buffer type is unknown
176 static bool
isCXXComplexTypeMatching(const TemplateSpecializationType * const Template,std::string & BufferTypeName,const std::string & MPIDatatype,const LangOptions & LO)177 isCXXComplexTypeMatching(const TemplateSpecializationType *const Template,
178                          std::string &BufferTypeName,
179                          const std::string &MPIDatatype,
180                          const LangOptions &LO) {
181   static std::multimap<BuiltinType::Kind, std::string> ComplexCXXMatches = {
182       {BuiltinType::Float, "MPI_CXX_FLOAT_COMPLEX"},
183       {BuiltinType::Double, "MPI_CXX_DOUBLE_COMPLEX"},
184       {BuiltinType::LongDouble, "MPI_CXX_LONG_DOUBLE_COMPLEX"}};
185 
186   if (Template->getAsCXXRecordDecl()->getName() != "complex")
187     return true;
188 
189   const auto *Builtin =
190       Template->getArg(0).getAsType().getTypePtr()->getAs<BuiltinType>();
191 
192   if (Builtin &&
193       !isMPITypeMatching(ComplexCXXMatches, Builtin->getKind(), MPIDatatype)) {
194     BufferTypeName =
195         (llvm::Twine("complex<") + Builtin->getName(LO) + ">").str();
196     return false;
197   }
198 
199   return true;
200 }
201 
202 /// Check if a fixed size width buffer type matches the MPI datatype.
203 ///
204 /// \param Typedef buffer type
205 /// \param BufferTypeName buffer type name, gets assigned
206 /// \param MPIDatatype name of the MPI datatype
207 ///
208 /// \returns true if the type matches or the buffer type is unknown
isTypedefTypeMatching(const TypedefType * const Typedef,std::string & BufferTypeName,const std::string & MPIDatatype)209 static bool isTypedefTypeMatching(const TypedefType *const Typedef,
210                                   std::string &BufferTypeName,
211                                   const std::string &MPIDatatype) {
212   static llvm::StringMap<std::string> FixedWidthMatches = {
213       {"int8_t", "MPI_INT8_T"},     {"int16_t", "MPI_INT16_T"},
214       {"int32_t", "MPI_INT32_T"},   {"int64_t", "MPI_INT64_T"},
215       {"uint8_t", "MPI_UINT8_T"},   {"uint16_t", "MPI_UINT16_T"},
216       {"uint32_t", "MPI_UINT32_T"}, {"uint64_t", "MPI_UINT64_T"}};
217 
218   const auto It = FixedWidthMatches.find(Typedef->getDecl()->getName());
219   // Check if the typedef is known and not matching the MPI datatype.
220   if (It != FixedWidthMatches.end() && It->getValue() != MPIDatatype) {
221     BufferTypeName = std::string(Typedef->getDecl()->getName());
222     return false;
223   }
224   return true;
225 }
226 
227 /// Get the unqualified, dereferenced type of an argument.
228 ///
229 /// \param CE call expression
230 /// \param Idx argument index
231 ///
232 /// \returns type of the argument
argumentType(const CallExpr * const CE,const size_t Idx)233 static const Type *argumentType(const CallExpr *const CE, const size_t Idx) {
234   const QualType QT = CE->getArg(Idx)->IgnoreImpCasts()->getType();
235   return QT.getTypePtr()->getPointeeOrArrayElementType();
236 }
237 
registerMatchers(MatchFinder * Finder)238 void TypeMismatchCheck::registerMatchers(MatchFinder *Finder) {
239   Finder->addMatcher(callExpr().bind("CE"), this);
240 }
241 
check(const MatchFinder::MatchResult & Result)242 void TypeMismatchCheck::check(const MatchFinder::MatchResult &Result) {
243   const auto *const CE = Result.Nodes.getNodeAs<CallExpr>("CE");
244   if (!CE->getDirectCallee())
245     return;
246 
247   if (!FuncClassifier)
248     FuncClassifier.emplace(*Result.Context);
249 
250   const IdentifierInfo *Identifier = CE->getDirectCallee()->getIdentifier();
251   if (!Identifier || !FuncClassifier->isMPIType(Identifier))
252     return;
253 
254   // These containers are used, to capture buffer, MPI datatype pairs.
255   SmallVector<const Type *, 1> BufferTypes;
256   SmallVector<const Expr *, 1> BufferExprs;
257   SmallVector<StringRef, 1> MPIDatatypes;
258 
259   // Adds a buffer, MPI datatype pair of an MPI call expression to the
260   // containers. For buffers, the type and expression is captured.
261   auto AddPair = [&CE, &Result, &BufferTypes, &BufferExprs, &MPIDatatypes](
262                      const size_t BufferIdx, const size_t DatatypeIdx) {
263     // Skip null pointer constants and in place 'operators'.
264     if (CE->getArg(BufferIdx)->isNullPointerConstant(
265             *Result.Context, Expr::NPC_ValueDependentIsNull) ||
266         tooling::fixit::getText(*CE->getArg(BufferIdx), *Result.Context) ==
267             "MPI_IN_PLACE")
268       return;
269 
270     StringRef MPIDatatype =
271         tooling::fixit::getText(*CE->getArg(DatatypeIdx), *Result.Context);
272 
273     const Type *ArgType = argumentType(CE, BufferIdx);
274     // Skip unknown MPI datatypes and void pointers.
275     if (!isStandardMPIDatatype(std::string(MPIDatatype)) ||
276         ArgType->isVoidType())
277       return;
278 
279     BufferTypes.push_back(ArgType);
280     BufferExprs.push_back(CE->getArg(BufferIdx));
281     MPIDatatypes.push_back(MPIDatatype);
282   };
283 
284   // Collect all buffer, MPI datatype pairs for the inspected call expression.
285   if (FuncClassifier->isPointToPointType(Identifier)) {
286     AddPair(0, 2);
287   } else if (FuncClassifier->isCollectiveType(Identifier)) {
288     if (FuncClassifier->isReduceType(Identifier)) {
289       AddPair(0, 3);
290       AddPair(1, 3);
291     } else if (FuncClassifier->isScatterType(Identifier) ||
292                FuncClassifier->isGatherType(Identifier) ||
293                FuncClassifier->isAlltoallType(Identifier)) {
294       AddPair(0, 2);
295       AddPair(3, 5);
296     } else if (FuncClassifier->isBcastType(Identifier)) {
297       AddPair(0, 2);
298     }
299   }
300   checkArguments(BufferTypes, BufferExprs, MPIDatatypes, getLangOpts());
301 }
302 
checkArguments(ArrayRef<const Type * > BufferTypes,ArrayRef<const Expr * > BufferExprs,ArrayRef<StringRef> MPIDatatypes,const LangOptions & LO)303 void TypeMismatchCheck::checkArguments(ArrayRef<const Type *> BufferTypes,
304                                        ArrayRef<const Expr *> BufferExprs,
305                                        ArrayRef<StringRef> MPIDatatypes,
306                                        const LangOptions &LO) {
307   std::string BufferTypeName;
308 
309   for (size_t I = 0; I < MPIDatatypes.size(); ++I) {
310     const Type *const BT = BufferTypes[I];
311     bool Error = false;
312 
313     if (const auto *Typedef = BT->getAs<TypedefType>()) {
314       Error = !isTypedefTypeMatching(Typedef, BufferTypeName,
315                                      std::string(MPIDatatypes[I]));
316     } else if (const auto *Complex = BT->getAs<ComplexType>()) {
317       Error = !isCComplexTypeMatching(Complex, BufferTypeName,
318                                       std::string(MPIDatatypes[I]), LO);
319     } else if (const auto *Template = BT->getAs<TemplateSpecializationType>()) {
320       Error = !isCXXComplexTypeMatching(Template, BufferTypeName,
321                                         std::string(MPIDatatypes[I]), LO);
322     } else if (const auto *Builtin = BT->getAs<BuiltinType>()) {
323       Error = !isBuiltinTypeMatching(Builtin, BufferTypeName,
324                                      std::string(MPIDatatypes[I]), LO);
325     }
326 
327     if (Error) {
328       const auto Loc = BufferExprs[I]->getSourceRange().getBegin();
329       diag(Loc, "buffer type '%0' does not match the MPI datatype '%1'")
330           << BufferTypeName << MPIDatatypes[I];
331     }
332   }
333 }
334 
onEndOfTranslationUnit()335 void TypeMismatchCheck::onEndOfTranslationUnit() { FuncClassifier.reset(); }
336 } // namespace mpi
337 } // namespace tidy
338 } // namespace clang
339