1 //===--- DurationFactoryScaleCheck.cpp - clang-tidy -----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "DurationFactoryScaleCheck.h"
10 #include "DurationRewriter.h"
11 #include "clang/AST/ASTContext.h"
12 #include "clang/ASTMatchers/ASTMatchFinder.h"
13 #include "clang/Tooling/FixIt.h"
14 
15 using namespace clang::ast_matchers;
16 
17 namespace clang {
18 namespace tidy {
19 namespace abseil {
20 
21 // Given the name of a duration factory function, return the appropriate
22 // `DurationScale` for that factory.  If no factory can be found for
23 // `FactoryName`, return `None`.
24 static llvm::Optional<DurationScale>
getScaleForFactory(llvm::StringRef FactoryName)25 getScaleForFactory(llvm::StringRef FactoryName) {
26   static const std::unordered_map<std::string, DurationScale> ScaleMap(
27       {{"Nanoseconds", DurationScale::Nanoseconds},
28        {"Microseconds", DurationScale::Microseconds},
29        {"Milliseconds", DurationScale::Milliseconds},
30        {"Seconds", DurationScale::Seconds},
31        {"Minutes", DurationScale::Minutes},
32        {"Hours", DurationScale::Hours}});
33 
34   auto ScaleIter = ScaleMap.find(FactoryName);
35   if (ScaleIter == ScaleMap.end())
36     return llvm::None;
37 
38   return ScaleIter->second;
39 }
40 
41 // Given either an integer or float literal, return its value.
42 // One and only one of `IntLit` and `FloatLit` should be provided.
GetValue(const IntegerLiteral * IntLit,const FloatingLiteral * FloatLit)43 static double GetValue(const IntegerLiteral *IntLit,
44                        const FloatingLiteral *FloatLit) {
45   if (IntLit)
46     return IntLit->getValue().getLimitedValue();
47 
48   assert(FloatLit != nullptr && "Neither IntLit nor FloatLit set");
49   return FloatLit->getValueAsApproximateDouble();
50 }
51 
52 // Given the scale of a duration and a `Multiplier`, determine if `Multiplier`
53 // would produce a new scale.  If so, return a tuple containing the new scale
54 // and a suitable Multipler for that scale, otherwise `None`.
55 static llvm::Optional<std::tuple<DurationScale, double>>
GetNewScaleSingleStep(DurationScale OldScale,double Multiplier)56 GetNewScaleSingleStep(DurationScale OldScale, double Multiplier) {
57   switch (OldScale) {
58   case DurationScale::Hours:
59     if (Multiplier <= 1.0 / 60.0)
60       return std::make_tuple(DurationScale::Minutes, Multiplier * 60.0);
61     break;
62 
63   case DurationScale::Minutes:
64     if (Multiplier >= 60.0)
65       return std::make_tuple(DurationScale::Hours, Multiplier / 60.0);
66     if (Multiplier <= 1.0 / 60.0)
67       return std::make_tuple(DurationScale::Seconds, Multiplier * 60.0);
68     break;
69 
70   case DurationScale::Seconds:
71     if (Multiplier >= 60.0)
72       return std::make_tuple(DurationScale::Minutes, Multiplier / 60.0);
73     if (Multiplier <= 1e-3)
74       return std::make_tuple(DurationScale::Milliseconds, Multiplier * 1e3);
75     break;
76 
77   case DurationScale::Milliseconds:
78     if (Multiplier >= 1e3)
79       return std::make_tuple(DurationScale::Seconds, Multiplier / 1e3);
80     if (Multiplier <= 1e-3)
81       return std::make_tuple(DurationScale::Microseconds, Multiplier * 1e3);
82     break;
83 
84   case DurationScale::Microseconds:
85     if (Multiplier >= 1e3)
86       return std::make_tuple(DurationScale::Milliseconds, Multiplier / 1e3);
87     if (Multiplier <= 1e-3)
88       return std::make_tuple(DurationScale::Nanoseconds, Multiplier * 1e-3);
89     break;
90 
91   case DurationScale::Nanoseconds:
92     if (Multiplier >= 1e3)
93       return std::make_tuple(DurationScale::Microseconds, Multiplier / 1e3);
94     break;
95   }
96 
97   return llvm::None;
98 }
99 
100 // Given the scale of a duration and a `Multiplier`, determine if `Multiplier`
101 // would produce a new scale.  If so, return it, otherwise `None`.
GetNewScale(DurationScale OldScale,double Multiplier)102 static llvm::Optional<DurationScale> GetNewScale(DurationScale OldScale,
103                                                  double Multiplier) {
104   while (Multiplier != 1.0) {
105     llvm::Optional<std::tuple<DurationScale, double>> result =
106         GetNewScaleSingleStep(OldScale, Multiplier);
107     if (!result)
108       break;
109     if (std::get<1>(*result) == 1.0)
110       return std::get<0>(*result);
111     Multiplier = std::get<1>(*result);
112     OldScale = std::get<0>(*result);
113   }
114 
115   return llvm::None;
116 }
117 
registerMatchers(MatchFinder * Finder)118 void DurationFactoryScaleCheck::registerMatchers(MatchFinder *Finder) {
119   Finder->addMatcher(
120       callExpr(
121           callee(functionDecl(DurationFactoryFunction()).bind("call_decl")),
122           hasArgument(
123               0,
124               ignoringImpCasts(anyOf(
125                   cxxFunctionalCastExpr(
126                       hasDestinationType(
127                           anyOf(isInteger(), realFloatingPointType())),
128                       hasSourceExpression(initListExpr())),
129                   integerLiteral(equals(0)), floatLiteral(equals(0.0)),
130                   binaryOperator(hasOperatorName("*"),
131                                  hasEitherOperand(ignoringImpCasts(
132                                      anyOf(integerLiteral(), floatLiteral()))))
133                       .bind("mult_binop"),
134                   binaryOperator(hasOperatorName("/"), hasRHS(floatLiteral()))
135                       .bind("div_binop")))))
136           .bind("call"),
137       this);
138 }
139 
check(const MatchFinder::MatchResult & Result)140 void DurationFactoryScaleCheck::check(const MatchFinder::MatchResult &Result) {
141   const auto *Call = Result.Nodes.getNodeAs<CallExpr>("call");
142 
143   // Don't try to replace things inside of macro definitions.
144   if (Call->getExprLoc().isMacroID())
145     return;
146 
147   const Expr *Arg = Call->getArg(0)->IgnoreParenImpCasts();
148   // Arguments which are macros are ignored.
149   if (Arg->getBeginLoc().isMacroID())
150     return;
151 
152   // We first handle the cases of literal zero (both float and integer).
153   if (IsLiteralZero(Result, *Arg)) {
154     diag(Call->getBeginLoc(),
155          "use ZeroDuration() for zero-length time intervals")
156         << FixItHint::CreateReplacement(Call->getSourceRange(),
157                                         "absl::ZeroDuration()");
158     return;
159   }
160 
161   const auto *CallDecl = Result.Nodes.getNodeAs<FunctionDecl>("call_decl");
162   llvm::Optional<DurationScale> MaybeScale =
163       getScaleForFactory(CallDecl->getName());
164   if (!MaybeScale)
165     return;
166 
167   DurationScale Scale = *MaybeScale;
168   const Expr *Remainder;
169   llvm::Optional<DurationScale> NewScale;
170 
171   // We next handle the cases of multiplication and division.
172   if (const auto *MultBinOp =
173           Result.Nodes.getNodeAs<BinaryOperator>("mult_binop")) {
174     // For multiplication, we need to look at both operands, and consider the
175     // cases where a user is multiplying by something such as 1e-3.
176 
177     // First check the LHS
178     const auto *IntLit = llvm::dyn_cast<IntegerLiteral>(MultBinOp->getLHS());
179     const auto *FloatLit = llvm::dyn_cast<FloatingLiteral>(MultBinOp->getLHS());
180     if (IntLit || FloatLit) {
181       NewScale = GetNewScale(Scale, GetValue(IntLit, FloatLit));
182       if (NewScale)
183         Remainder = MultBinOp->getRHS();
184     }
185 
186     // If we weren't able to scale based on the LHS, check the RHS
187     if (!NewScale) {
188       IntLit = llvm::dyn_cast<IntegerLiteral>(MultBinOp->getRHS());
189       FloatLit = llvm::dyn_cast<FloatingLiteral>(MultBinOp->getRHS());
190       if (IntLit || FloatLit) {
191         NewScale = GetNewScale(Scale, GetValue(IntLit, FloatLit));
192         if (NewScale)
193           Remainder = MultBinOp->getLHS();
194       }
195     }
196   } else if (const auto *DivBinOp =
197                  Result.Nodes.getNodeAs<BinaryOperator>("div_binop")) {
198     // We next handle division.
199     // For division, we only check the RHS.
200     const auto *FloatLit = llvm::dyn_cast<FloatingLiteral>(DivBinOp->getRHS());
201 
202     llvm::Optional<DurationScale> NewScale =
203         GetNewScale(Scale, 1.0 / FloatLit->getValueAsApproximateDouble());
204     if (NewScale) {
205       const Expr *Remainder = DivBinOp->getLHS();
206 
207       // We've found an appropriate scaling factor and the new scale, so output
208       // the relevant fix.
209       diag(Call->getBeginLoc(), "internal duration scaling can be removed")
210           << FixItHint::CreateReplacement(
211                  Call->getSourceRange(),
212                  (llvm::Twine(getDurationFactoryForScale(*NewScale)) + "(" +
213                   tooling::fixit::getText(*Remainder, *Result.Context) + ")")
214                      .str());
215     }
216   }
217 
218   if (NewScale) {
219     assert(Remainder && "No remainder found");
220     // We've found an appropriate scaling factor and the new scale, so output
221     // the relevant fix.
222     diag(Call->getBeginLoc(), "internal duration scaling can be removed")
223         << FixItHint::CreateReplacement(
224                Call->getSourceRange(),
225                (llvm::Twine(getDurationFactoryForScale(*NewScale)) + "(" +
226                 tooling::fixit::getText(*Remainder, *Result.Context) + ")")
227                    .str());
228   }
229   return;
230 }
231 
232 } // namespace abseil
233 } // namespace tidy
234 } // namespace clang
235