1 //===--- DurationFactoryScaleCheck.cpp - clang-tidy -----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "DurationFactoryScaleCheck.h"
10 #include "DurationRewriter.h"
11 #include "clang/AST/ASTContext.h"
12 #include "clang/ASTMatchers/ASTMatchFinder.h"
13 #include "clang/Tooling/FixIt.h"
14
15 using namespace clang::ast_matchers;
16
17 namespace clang {
18 namespace tidy {
19 namespace abseil {
20
21 // Given the name of a duration factory function, return the appropriate
22 // `DurationScale` for that factory. If no factory can be found for
23 // `FactoryName`, return `None`.
24 static llvm::Optional<DurationScale>
getScaleForFactory(llvm::StringRef FactoryName)25 getScaleForFactory(llvm::StringRef FactoryName) {
26 static const std::unordered_map<std::string, DurationScale> ScaleMap(
27 {{"Nanoseconds", DurationScale::Nanoseconds},
28 {"Microseconds", DurationScale::Microseconds},
29 {"Milliseconds", DurationScale::Milliseconds},
30 {"Seconds", DurationScale::Seconds},
31 {"Minutes", DurationScale::Minutes},
32 {"Hours", DurationScale::Hours}});
33
34 auto ScaleIter = ScaleMap.find(FactoryName);
35 if (ScaleIter == ScaleMap.end())
36 return llvm::None;
37
38 return ScaleIter->second;
39 }
40
41 // Given either an integer or float literal, return its value.
42 // One and only one of `IntLit` and `FloatLit` should be provided.
GetValue(const IntegerLiteral * IntLit,const FloatingLiteral * FloatLit)43 static double GetValue(const IntegerLiteral *IntLit,
44 const FloatingLiteral *FloatLit) {
45 if (IntLit)
46 return IntLit->getValue().getLimitedValue();
47
48 assert(FloatLit != nullptr && "Neither IntLit nor FloatLit set");
49 return FloatLit->getValueAsApproximateDouble();
50 }
51
52 // Given the scale of a duration and a `Multiplier`, determine if `Multiplier`
53 // would produce a new scale. If so, return a tuple containing the new scale
54 // and a suitable Multipler for that scale, otherwise `None`.
55 static llvm::Optional<std::tuple<DurationScale, double>>
GetNewScaleSingleStep(DurationScale OldScale,double Multiplier)56 GetNewScaleSingleStep(DurationScale OldScale, double Multiplier) {
57 switch (OldScale) {
58 case DurationScale::Hours:
59 if (Multiplier <= 1.0 / 60.0)
60 return std::make_tuple(DurationScale::Minutes, Multiplier * 60.0);
61 break;
62
63 case DurationScale::Minutes:
64 if (Multiplier >= 60.0)
65 return std::make_tuple(DurationScale::Hours, Multiplier / 60.0);
66 if (Multiplier <= 1.0 / 60.0)
67 return std::make_tuple(DurationScale::Seconds, Multiplier * 60.0);
68 break;
69
70 case DurationScale::Seconds:
71 if (Multiplier >= 60.0)
72 return std::make_tuple(DurationScale::Minutes, Multiplier / 60.0);
73 if (Multiplier <= 1e-3)
74 return std::make_tuple(DurationScale::Milliseconds, Multiplier * 1e3);
75 break;
76
77 case DurationScale::Milliseconds:
78 if (Multiplier >= 1e3)
79 return std::make_tuple(DurationScale::Seconds, Multiplier / 1e3);
80 if (Multiplier <= 1e-3)
81 return std::make_tuple(DurationScale::Microseconds, Multiplier * 1e3);
82 break;
83
84 case DurationScale::Microseconds:
85 if (Multiplier >= 1e3)
86 return std::make_tuple(DurationScale::Milliseconds, Multiplier / 1e3);
87 if (Multiplier <= 1e-3)
88 return std::make_tuple(DurationScale::Nanoseconds, Multiplier * 1e-3);
89 break;
90
91 case DurationScale::Nanoseconds:
92 if (Multiplier >= 1e3)
93 return std::make_tuple(DurationScale::Microseconds, Multiplier / 1e3);
94 break;
95 }
96
97 return llvm::None;
98 }
99
100 // Given the scale of a duration and a `Multiplier`, determine if `Multiplier`
101 // would produce a new scale. If so, return it, otherwise `None`.
GetNewScale(DurationScale OldScale,double Multiplier)102 static llvm::Optional<DurationScale> GetNewScale(DurationScale OldScale,
103 double Multiplier) {
104 while (Multiplier != 1.0) {
105 llvm::Optional<std::tuple<DurationScale, double>> result =
106 GetNewScaleSingleStep(OldScale, Multiplier);
107 if (!result)
108 break;
109 if (std::get<1>(*result) == 1.0)
110 return std::get<0>(*result);
111 Multiplier = std::get<1>(*result);
112 OldScale = std::get<0>(*result);
113 }
114
115 return llvm::None;
116 }
117
registerMatchers(MatchFinder * Finder)118 void DurationFactoryScaleCheck::registerMatchers(MatchFinder *Finder) {
119 Finder->addMatcher(
120 callExpr(
121 callee(functionDecl(DurationFactoryFunction()).bind("call_decl")),
122 hasArgument(
123 0,
124 ignoringImpCasts(anyOf(
125 cxxFunctionalCastExpr(
126 hasDestinationType(
127 anyOf(isInteger(), realFloatingPointType())),
128 hasSourceExpression(initListExpr())),
129 integerLiteral(equals(0)), floatLiteral(equals(0.0)),
130 binaryOperator(hasOperatorName("*"),
131 hasEitherOperand(ignoringImpCasts(
132 anyOf(integerLiteral(), floatLiteral()))))
133 .bind("mult_binop"),
134 binaryOperator(hasOperatorName("/"), hasRHS(floatLiteral()))
135 .bind("div_binop")))))
136 .bind("call"),
137 this);
138 }
139
check(const MatchFinder::MatchResult & Result)140 void DurationFactoryScaleCheck::check(const MatchFinder::MatchResult &Result) {
141 const auto *Call = Result.Nodes.getNodeAs<CallExpr>("call");
142
143 // Don't try to replace things inside of macro definitions.
144 if (Call->getExprLoc().isMacroID())
145 return;
146
147 const Expr *Arg = Call->getArg(0)->IgnoreParenImpCasts();
148 // Arguments which are macros are ignored.
149 if (Arg->getBeginLoc().isMacroID())
150 return;
151
152 // We first handle the cases of literal zero (both float and integer).
153 if (IsLiteralZero(Result, *Arg)) {
154 diag(Call->getBeginLoc(),
155 "use ZeroDuration() for zero-length time intervals")
156 << FixItHint::CreateReplacement(Call->getSourceRange(),
157 "absl::ZeroDuration()");
158 return;
159 }
160
161 const auto *CallDecl = Result.Nodes.getNodeAs<FunctionDecl>("call_decl");
162 llvm::Optional<DurationScale> MaybeScale =
163 getScaleForFactory(CallDecl->getName());
164 if (!MaybeScale)
165 return;
166
167 DurationScale Scale = *MaybeScale;
168 const Expr *Remainder;
169 llvm::Optional<DurationScale> NewScale;
170
171 // We next handle the cases of multiplication and division.
172 if (const auto *MultBinOp =
173 Result.Nodes.getNodeAs<BinaryOperator>("mult_binop")) {
174 // For multiplication, we need to look at both operands, and consider the
175 // cases where a user is multiplying by something such as 1e-3.
176
177 // First check the LHS
178 const auto *IntLit = llvm::dyn_cast<IntegerLiteral>(MultBinOp->getLHS());
179 const auto *FloatLit = llvm::dyn_cast<FloatingLiteral>(MultBinOp->getLHS());
180 if (IntLit || FloatLit) {
181 NewScale = GetNewScale(Scale, GetValue(IntLit, FloatLit));
182 if (NewScale)
183 Remainder = MultBinOp->getRHS();
184 }
185
186 // If we weren't able to scale based on the LHS, check the RHS
187 if (!NewScale) {
188 IntLit = llvm::dyn_cast<IntegerLiteral>(MultBinOp->getRHS());
189 FloatLit = llvm::dyn_cast<FloatingLiteral>(MultBinOp->getRHS());
190 if (IntLit || FloatLit) {
191 NewScale = GetNewScale(Scale, GetValue(IntLit, FloatLit));
192 if (NewScale)
193 Remainder = MultBinOp->getLHS();
194 }
195 }
196 } else if (const auto *DivBinOp =
197 Result.Nodes.getNodeAs<BinaryOperator>("div_binop")) {
198 // We next handle division.
199 // For division, we only check the RHS.
200 const auto *FloatLit = llvm::dyn_cast<FloatingLiteral>(DivBinOp->getRHS());
201
202 llvm::Optional<DurationScale> NewScale =
203 GetNewScale(Scale, 1.0 / FloatLit->getValueAsApproximateDouble());
204 if (NewScale) {
205 const Expr *Remainder = DivBinOp->getLHS();
206
207 // We've found an appropriate scaling factor and the new scale, so output
208 // the relevant fix.
209 diag(Call->getBeginLoc(), "internal duration scaling can be removed")
210 << FixItHint::CreateReplacement(
211 Call->getSourceRange(),
212 (llvm::Twine(getDurationFactoryForScale(*NewScale)) + "(" +
213 tooling::fixit::getText(*Remainder, *Result.Context) + ")")
214 .str());
215 }
216 }
217
218 if (NewScale) {
219 assert(Remainder && "No remainder found");
220 // We've found an appropriate scaling factor and the new scale, so output
221 // the relevant fix.
222 diag(Call->getBeginLoc(), "internal duration scaling can be removed")
223 << FixItHint::CreateReplacement(
224 Call->getSourceRange(),
225 (llvm::Twine(getDurationFactoryForScale(*NewScale)) + "(" +
226 tooling::fixit::getText(*Remainder, *Result.Context) + ")")
227 .str());
228 }
229 return;
230 }
231
232 } // namespace abseil
233 } // namespace tidy
234 } // namespace clang
235