1 //===--- DurationFactoryScaleCheck.cpp - clang-tidy -----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "DurationFactoryScaleCheck.h"
10 #include "DurationRewriter.h"
11 #include "clang/AST/ASTContext.h"
12 #include "clang/ASTMatchers/ASTMatchFinder.h"
13 #include "clang/Tooling/FixIt.h"
14 
15 using namespace clang::ast_matchers;
16 
17 namespace clang {
18 namespace tidy {
19 namespace abseil {
20 
21 // Given the name of a duration factory function, return the appropriate
22 // `DurationScale` for that factory.  If no factory can be found for
23 // `FactoryName`, return `None`.
24 static llvm::Optional<DurationScale>
getScaleForFactory(llvm::StringRef FactoryName)25 getScaleForFactory(llvm::StringRef FactoryName) {
26   return llvm::StringSwitch<llvm::Optional<DurationScale>>(FactoryName)
27       .Case("Nanoseconds", DurationScale::Nanoseconds)
28       .Case("Microseconds", DurationScale::Microseconds)
29       .Case("Milliseconds", DurationScale::Milliseconds)
30       .Case("Seconds", DurationScale::Seconds)
31       .Case("Minutes", DurationScale::Minutes)
32       .Case("Hours", DurationScale::Hours)
33       .Default(llvm::None);
34 }
35 
36 // Given either an integer or float literal, return its value.
37 // One and only one of `IntLit` and `FloatLit` should be provided.
getValue(const IntegerLiteral * IntLit,const FloatingLiteral * FloatLit)38 static double getValue(const IntegerLiteral *IntLit,
39                        const FloatingLiteral *FloatLit) {
40   if (IntLit)
41     return IntLit->getValue().getLimitedValue();
42 
43   assert(FloatLit != nullptr && "Neither IntLit nor FloatLit set");
44   return FloatLit->getValueAsApproximateDouble();
45 }
46 
47 // Given the scale of a duration and a `Multiplier`, determine if `Multiplier`
48 // would produce a new scale.  If so, return a tuple containing the new scale
49 // and a suitable Multiplier for that scale, otherwise `None`.
50 static llvm::Optional<std::tuple<DurationScale, double>>
getNewScaleSingleStep(DurationScale OldScale,double Multiplier)51 getNewScaleSingleStep(DurationScale OldScale, double Multiplier) {
52   switch (OldScale) {
53   case DurationScale::Hours:
54     if (Multiplier <= 1.0 / 60.0)
55       return std::make_tuple(DurationScale::Minutes, Multiplier * 60.0);
56     break;
57 
58   case DurationScale::Minutes:
59     if (Multiplier >= 60.0)
60       return std::make_tuple(DurationScale::Hours, Multiplier / 60.0);
61     if (Multiplier <= 1.0 / 60.0)
62       return std::make_tuple(DurationScale::Seconds, Multiplier * 60.0);
63     break;
64 
65   case DurationScale::Seconds:
66     if (Multiplier >= 60.0)
67       return std::make_tuple(DurationScale::Minutes, Multiplier / 60.0);
68     if (Multiplier <= 1e-3)
69       return std::make_tuple(DurationScale::Milliseconds, Multiplier * 1e3);
70     break;
71 
72   case DurationScale::Milliseconds:
73     if (Multiplier >= 1e3)
74       return std::make_tuple(DurationScale::Seconds, Multiplier / 1e3);
75     if (Multiplier <= 1e-3)
76       return std::make_tuple(DurationScale::Microseconds, Multiplier * 1e3);
77     break;
78 
79   case DurationScale::Microseconds:
80     if (Multiplier >= 1e3)
81       return std::make_tuple(DurationScale::Milliseconds, Multiplier / 1e3);
82     if (Multiplier <= 1e-3)
83       return std::make_tuple(DurationScale::Nanoseconds, Multiplier * 1e-3);
84     break;
85 
86   case DurationScale::Nanoseconds:
87     if (Multiplier >= 1e3)
88       return std::make_tuple(DurationScale::Microseconds, Multiplier / 1e3);
89     break;
90   }
91 
92   return llvm::None;
93 }
94 
95 // Given the scale of a duration and a `Multiplier`, determine if `Multiplier`
96 // would produce a new scale.  If so, return it, otherwise `None`.
getNewScale(DurationScale OldScale,double Multiplier)97 static llvm::Optional<DurationScale> getNewScale(DurationScale OldScale,
98                                                  double Multiplier) {
99   while (Multiplier != 1.0) {
100     llvm::Optional<std::tuple<DurationScale, double>> Result =
101         getNewScaleSingleStep(OldScale, Multiplier);
102     if (!Result)
103       break;
104     if (std::get<1>(*Result) == 1.0)
105       return std::get<0>(*Result);
106     Multiplier = std::get<1>(*Result);
107     OldScale = std::get<0>(*Result);
108   }
109 
110   return llvm::None;
111 }
112 
registerMatchers(MatchFinder * Finder)113 void DurationFactoryScaleCheck::registerMatchers(MatchFinder *Finder) {
114   Finder->addMatcher(
115       callExpr(
116           callee(functionDecl(DurationFactoryFunction()).bind("call_decl")),
117           hasArgument(
118               0,
119               ignoringImpCasts(anyOf(
120                   cxxFunctionalCastExpr(
121                       hasDestinationType(
122                           anyOf(isInteger(), realFloatingPointType())),
123                       hasSourceExpression(initListExpr())),
124                   integerLiteral(equals(0)), floatLiteral(equals(0.0)),
125                   binaryOperator(hasOperatorName("*"),
126                                  hasEitherOperand(ignoringImpCasts(
127                                      anyOf(integerLiteral(), floatLiteral()))))
128                       .bind("mult_binop"),
129                   binaryOperator(hasOperatorName("/"), hasRHS(floatLiteral()))
130                       .bind("div_binop")))))
131           .bind("call"),
132       this);
133 }
134 
check(const MatchFinder::MatchResult & Result)135 void DurationFactoryScaleCheck::check(const MatchFinder::MatchResult &Result) {
136   const auto *Call = Result.Nodes.getNodeAs<CallExpr>("call");
137 
138   // Don't try to replace things inside of macro definitions.
139   if (Call->getExprLoc().isMacroID())
140     return;
141 
142   const Expr *Arg = Call->getArg(0)->IgnoreParenImpCasts();
143   // Arguments which are macros are ignored.
144   if (Arg->getBeginLoc().isMacroID())
145     return;
146 
147   // We first handle the cases of literal zero (both float and integer).
148   if (IsLiteralZero(Result, *Arg)) {
149     diag(Call->getBeginLoc(),
150          "use ZeroDuration() for zero-length time intervals")
151         << FixItHint::CreateReplacement(Call->getSourceRange(),
152                                         "absl::ZeroDuration()");
153     return;
154   }
155 
156   const auto *CallDecl = Result.Nodes.getNodeAs<FunctionDecl>("call_decl");
157   llvm::Optional<DurationScale> MaybeScale =
158       getScaleForFactory(CallDecl->getName());
159   if (!MaybeScale)
160     return;
161 
162   DurationScale Scale = *MaybeScale;
163   const Expr *Remainder;
164   llvm::Optional<DurationScale> NewScale;
165 
166   // We next handle the cases of multiplication and division.
167   if (const auto *MultBinOp =
168           Result.Nodes.getNodeAs<BinaryOperator>("mult_binop")) {
169     // For multiplication, we need to look at both operands, and consider the
170     // cases where a user is multiplying by something such as 1e-3.
171 
172     // First check the LHS
173     const auto *IntLit = llvm::dyn_cast<IntegerLiteral>(MultBinOp->getLHS());
174     const auto *FloatLit = llvm::dyn_cast<FloatingLiteral>(MultBinOp->getLHS());
175     if (IntLit || FloatLit) {
176       NewScale = getNewScale(Scale, getValue(IntLit, FloatLit));
177       if (NewScale)
178         Remainder = MultBinOp->getRHS();
179     }
180 
181     // If we weren't able to scale based on the LHS, check the RHS
182     if (!NewScale) {
183       IntLit = llvm::dyn_cast<IntegerLiteral>(MultBinOp->getRHS());
184       FloatLit = llvm::dyn_cast<FloatingLiteral>(MultBinOp->getRHS());
185       if (IntLit || FloatLit) {
186         NewScale = getNewScale(Scale, getValue(IntLit, FloatLit));
187         if (NewScale)
188           Remainder = MultBinOp->getLHS();
189       }
190     }
191   } else if (const auto *DivBinOp =
192                  Result.Nodes.getNodeAs<BinaryOperator>("div_binop")) {
193     // We next handle division.
194     // For division, we only check the RHS.
195     const auto *FloatLit = llvm::dyn_cast<FloatingLiteral>(DivBinOp->getRHS());
196 
197     llvm::Optional<DurationScale> NewScale =
198         getNewScale(Scale, 1.0 / FloatLit->getValueAsApproximateDouble());
199     if (NewScale) {
200       const Expr *Remainder = DivBinOp->getLHS();
201 
202       // We've found an appropriate scaling factor and the new scale, so output
203       // the relevant fix.
204       diag(Call->getBeginLoc(), "internal duration scaling can be removed")
205           << FixItHint::CreateReplacement(
206                  Call->getSourceRange(),
207                  (llvm::Twine(getDurationFactoryForScale(*NewScale)) + "(" +
208                   tooling::fixit::getText(*Remainder, *Result.Context) + ")")
209                      .str());
210     }
211   }
212 
213   if (NewScale) {
214     assert(Remainder && "No remainder found");
215     // We've found an appropriate scaling factor and the new scale, so output
216     // the relevant fix.
217     diag(Call->getBeginLoc(), "internal duration scaling can be removed")
218         << FixItHint::CreateReplacement(
219                Call->getSourceRange(),
220                (llvm::Twine(getDurationFactoryForScale(*NewScale)) + "(" +
221                 tooling::fixit::getText(*Remainder, *Result.Context) + ")")
222                    .str());
223   }
224   return;
225 }
226 
227 } // namespace abseil
228 } // namespace tidy
229 } // namespace clang
230