1 //===- llvm/Analysis/ScalarEvolutionExpressions.h - SCEV Exprs --*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the classes used to represent and build scalar expressions.
10 //
11 //===----------------------------------------------------------------------===//
12
13 #ifndef LLVM_ANALYSIS_SCALAREVOLUTIONEXPRESSIONS_H
14 #define LLVM_ANALYSIS_SCALAREVOLUTIONEXPRESSIONS_H
15
16 #include "llvm/ADT/DenseMap.h"
17 #include "llvm/ADT/SmallPtrSet.h"
18 #include "llvm/ADT/SmallVector.h"
19 #include "llvm/Analysis/ScalarEvolution.h"
20 #include "llvm/IR/Constants.h"
21 #include "llvm/IR/ValueHandle.h"
22 #include "llvm/Support/Casting.h"
23 #include "llvm/Support/ErrorHandling.h"
24 #include <cassert>
25 #include <cstddef>
26
27 namespace llvm {
28
29 class APInt;
30 class Constant;
31 class ConstantInt;
32 class ConstantRange;
33 class Loop;
34 class Type;
35 class Value;
36
/// Kind tag returned by SCEV::getSCEVType(). The classof() methods and the
/// visitor/traversal switches in this file dispatch on these values; the
/// trailing comment on each enumerator names the node class whose classof()
/// (or visitor case) tests for it.
enum SCEVTypes : unsigned short {
  // These should be ordered in terms of increasing complexity to make the
  // folders simpler.
  scConstant,           // SCEVConstant
  scVScale,             // SCEVVScale
  scTruncate,           // SCEVTruncateExpr
  scZeroExtend,         // SCEVZeroExtendExpr
  scSignExtend,         // SCEVSignExtendExpr
  scAddExpr,            // SCEVAddExpr
  scMulExpr,            // SCEVMulExpr
  scUDivExpr,           // SCEVUDivExpr
  scAddRecExpr,         // SCEVAddRecExpr
  scUMaxExpr,           // SCEVUMaxExpr
  scSMaxExpr,           // SCEVSMaxExpr
  scUMinExpr,           // SCEVUMinExpr
  scSMinExpr,           // SCEVSMinExpr
  scSequentialUMinExpr, // SCEVSequentialUMinExpr
  scPtrToInt,           // SCEVPtrToIntExpr
  scUnknown,            // SCEVUnknown
  scCouldNotCompute     // SCEVCouldNotCompute
};
58
59 /// This class represents a constant integer value.
60 class SCEVConstant : public SCEV {
61 friend class ScalarEvolution;
62
63 ConstantInt *V;
64
SCEVConstant(const FoldingSetNodeIDRef ID,ConstantInt * v)65 SCEVConstant(const FoldingSetNodeIDRef ID, ConstantInt *v)
66 : SCEV(ID, scConstant, 1), V(v) {}
67
68 public:
getValue()69 ConstantInt *getValue() const { return V; }
getAPInt()70 const APInt &getAPInt() const { return getValue()->getValue(); }
71
getType()72 Type *getType() const { return V->getType(); }
73
74 /// Methods for support type inquiry through isa, cast, and dyn_cast:
classof(const SCEV * S)75 static bool classof(const SCEV *S) { return S->getSCEVType() == scConstant; }
76 };
77
78 /// This class represents the value of vscale, as used when defining the length
79 /// of a scalable vector or returned by the llvm.vscale() intrinsic.
80 class SCEVVScale : public SCEV {
81 friend class ScalarEvolution;
82
SCEVVScale(const FoldingSetNodeIDRef ID,Type * ty)83 SCEVVScale(const FoldingSetNodeIDRef ID, Type *ty)
84 : SCEV(ID, scVScale, 0), Ty(ty) {}
85
86 Type *Ty;
87
88 public:
getType()89 Type *getType() const { return Ty; }
90
91 /// Methods for support type inquiry through isa, cast, and dyn_cast:
classof(const SCEV * S)92 static bool classof(const SCEV *S) { return S->getSCEVType() == scVScale; }
93 };
94
computeExpressionSize(ArrayRef<const SCEV * > Args)95 inline unsigned short computeExpressionSize(ArrayRef<const SCEV *> Args) {
96 APInt Size(16, 1);
97 for (const auto *Arg : Args)
98 Size = Size.uadd_sat(APInt(16, Arg->getExpressionSize()));
99 return (unsigned short)Size.getZExtValue();
100 }
101
102 /// This is the base class for unary cast operator classes.
103 class SCEVCastExpr : public SCEV {
104 protected:
105 const SCEV *Op;
106 Type *Ty;
107
108 SCEVCastExpr(const FoldingSetNodeIDRef ID, SCEVTypes SCEVTy, const SCEV *op,
109 Type *ty);
110
111 public:
getOperand()112 const SCEV *getOperand() const { return Op; }
getOperand(unsigned i)113 const SCEV *getOperand(unsigned i) const {
114 assert(i == 0 && "Operand index out of range!");
115 return Op;
116 }
operands()117 ArrayRef<const SCEV *> operands() const { return Op; }
getNumOperands()118 size_t getNumOperands() const { return 1; }
getType()119 Type *getType() const { return Ty; }
120
121 /// Methods for support type inquiry through isa, cast, and dyn_cast:
classof(const SCEV * S)122 static bool classof(const SCEV *S) {
123 return S->getSCEVType() == scPtrToInt || S->getSCEVType() == scTruncate ||
124 S->getSCEVType() == scZeroExtend || S->getSCEVType() == scSignExtend;
125 }
126 };
127
128 /// This class represents a cast from a pointer to a pointer-sized integer
129 /// value.
130 class SCEVPtrToIntExpr : public SCEVCastExpr {
131 friend class ScalarEvolution;
132
133 SCEVPtrToIntExpr(const FoldingSetNodeIDRef ID, const SCEV *Op, Type *ITy);
134
135 public:
136 /// Methods for support type inquiry through isa, cast, and dyn_cast:
classof(const SCEV * S)137 static bool classof(const SCEV *S) { return S->getSCEVType() == scPtrToInt; }
138 };
139
140 /// This is the base class for unary integral cast operator classes.
141 class SCEVIntegralCastExpr : public SCEVCastExpr {
142 protected:
143 SCEVIntegralCastExpr(const FoldingSetNodeIDRef ID, SCEVTypes SCEVTy,
144 const SCEV *op, Type *ty);
145
146 public:
147 /// Methods for support type inquiry through isa, cast, and dyn_cast:
classof(const SCEV * S)148 static bool classof(const SCEV *S) {
149 return S->getSCEVType() == scTruncate || S->getSCEVType() == scZeroExtend ||
150 S->getSCEVType() == scSignExtend;
151 }
152 };
153
154 /// This class represents a truncation of an integer value to a
155 /// smaller integer value.
156 class SCEVTruncateExpr : public SCEVIntegralCastExpr {
157 friend class ScalarEvolution;
158
159 SCEVTruncateExpr(const FoldingSetNodeIDRef ID, const SCEV *op, Type *ty);
160
161 public:
162 /// Methods for support type inquiry through isa, cast, and dyn_cast:
classof(const SCEV * S)163 static bool classof(const SCEV *S) { return S->getSCEVType() == scTruncate; }
164 };
165
166 /// This class represents a zero extension of a small integer value
167 /// to a larger integer value.
168 class SCEVZeroExtendExpr : public SCEVIntegralCastExpr {
169 friend class ScalarEvolution;
170
171 SCEVZeroExtendExpr(const FoldingSetNodeIDRef ID, const SCEV *op, Type *ty);
172
173 public:
174 /// Methods for support type inquiry through isa, cast, and dyn_cast:
classof(const SCEV * S)175 static bool classof(const SCEV *S) {
176 return S->getSCEVType() == scZeroExtend;
177 }
178 };
179
180 /// This class represents a sign extension of a small integer value
181 /// to a larger integer value.
182 class SCEVSignExtendExpr : public SCEVIntegralCastExpr {
183 friend class ScalarEvolution;
184
185 SCEVSignExtendExpr(const FoldingSetNodeIDRef ID, const SCEV *op, Type *ty);
186
187 public:
188 /// Methods for support type inquiry through isa, cast, and dyn_cast:
classof(const SCEV * S)189 static bool classof(const SCEV *S) {
190 return S->getSCEVType() == scSignExtend;
191 }
192 };
193
194 /// This node is a base class providing common functionality for
195 /// n'ary operators.
196 class SCEVNAryExpr : public SCEV {
197 protected:
198 // Since SCEVs are immutable, ScalarEvolution allocates operand
199 // arrays with its SCEVAllocator, so this class just needs a simple
200 // pointer rather than a more elaborate vector-like data structure.
201 // This also avoids the need for a non-trivial destructor.
202 const SCEV *const *Operands;
203 size_t NumOperands;
204
SCEVNAryExpr(const FoldingSetNodeIDRef ID,enum SCEVTypes T,const SCEV * const * O,size_t N)205 SCEVNAryExpr(const FoldingSetNodeIDRef ID, enum SCEVTypes T,
206 const SCEV *const *O, size_t N)
207 : SCEV(ID, T, computeExpressionSize(ArrayRef(O, N))), Operands(O),
208 NumOperands(N) {}
209
210 public:
getNumOperands()211 size_t getNumOperands() const { return NumOperands; }
212
getOperand(unsigned i)213 const SCEV *getOperand(unsigned i) const {
214 assert(i < NumOperands && "Operand index out of range!");
215 return Operands[i];
216 }
217
operands()218 ArrayRef<const SCEV *> operands() const {
219 return ArrayRef(Operands, NumOperands);
220 }
221
222 NoWrapFlags getNoWrapFlags(NoWrapFlags Mask = NoWrapMask) const {
223 return (NoWrapFlags)(SubclassData & Mask);
224 }
225
hasNoUnsignedWrap()226 bool hasNoUnsignedWrap() const {
227 return getNoWrapFlags(FlagNUW) != FlagAnyWrap;
228 }
229
hasNoSignedWrap()230 bool hasNoSignedWrap() const {
231 return getNoWrapFlags(FlagNSW) != FlagAnyWrap;
232 }
233
hasNoSelfWrap()234 bool hasNoSelfWrap() const { return getNoWrapFlags(FlagNW) != FlagAnyWrap; }
235
236 /// Methods for support type inquiry through isa, cast, and dyn_cast:
classof(const SCEV * S)237 static bool classof(const SCEV *S) {
238 return S->getSCEVType() == scAddExpr || S->getSCEVType() == scMulExpr ||
239 S->getSCEVType() == scSMaxExpr || S->getSCEVType() == scUMaxExpr ||
240 S->getSCEVType() == scSMinExpr || S->getSCEVType() == scUMinExpr ||
241 S->getSCEVType() == scSequentialUMinExpr ||
242 S->getSCEVType() == scAddRecExpr;
243 }
244 };
245
246 /// This node is the base class for n'ary commutative operators.
247 class SCEVCommutativeExpr : public SCEVNAryExpr {
248 protected:
SCEVCommutativeExpr(const FoldingSetNodeIDRef ID,enum SCEVTypes T,const SCEV * const * O,size_t N)249 SCEVCommutativeExpr(const FoldingSetNodeIDRef ID, enum SCEVTypes T,
250 const SCEV *const *O, size_t N)
251 : SCEVNAryExpr(ID, T, O, N) {}
252
253 public:
254 /// Methods for support type inquiry through isa, cast, and dyn_cast:
classof(const SCEV * S)255 static bool classof(const SCEV *S) {
256 return S->getSCEVType() == scAddExpr || S->getSCEVType() == scMulExpr ||
257 S->getSCEVType() == scSMaxExpr || S->getSCEVType() == scUMaxExpr ||
258 S->getSCEVType() == scSMinExpr || S->getSCEVType() == scUMinExpr;
259 }
260
261 /// Set flags for a non-recurrence without clearing previously set flags.
setNoWrapFlags(NoWrapFlags Flags)262 void setNoWrapFlags(NoWrapFlags Flags) { SubclassData |= Flags; }
263 };
264
265 /// This node represents an addition of some number of SCEVs.
266 class SCEVAddExpr : public SCEVCommutativeExpr {
267 friend class ScalarEvolution;
268
269 Type *Ty;
270
SCEVAddExpr(const FoldingSetNodeIDRef ID,const SCEV * const * O,size_t N)271 SCEVAddExpr(const FoldingSetNodeIDRef ID, const SCEV *const *O, size_t N)
272 : SCEVCommutativeExpr(ID, scAddExpr, O, N) {
273 auto *FirstPointerTypedOp = find_if(operands(), [](const SCEV *Op) {
274 return Op->getType()->isPointerTy();
275 });
276 if (FirstPointerTypedOp != operands().end())
277 Ty = (*FirstPointerTypedOp)->getType();
278 else
279 Ty = getOperand(0)->getType();
280 }
281
282 public:
getType()283 Type *getType() const { return Ty; }
284
285 /// Methods for support type inquiry through isa, cast, and dyn_cast:
classof(const SCEV * S)286 static bool classof(const SCEV *S) { return S->getSCEVType() == scAddExpr; }
287 };
288
289 /// This node represents multiplication of some number of SCEVs.
290 class SCEVMulExpr : public SCEVCommutativeExpr {
291 friend class ScalarEvolution;
292
SCEVMulExpr(const FoldingSetNodeIDRef ID,const SCEV * const * O,size_t N)293 SCEVMulExpr(const FoldingSetNodeIDRef ID, const SCEV *const *O, size_t N)
294 : SCEVCommutativeExpr(ID, scMulExpr, O, N) {}
295
296 public:
getType()297 Type *getType() const { return getOperand(0)->getType(); }
298
299 /// Methods for support type inquiry through isa, cast, and dyn_cast:
classof(const SCEV * S)300 static bool classof(const SCEV *S) { return S->getSCEVType() == scMulExpr; }
301 };
302
303 /// This class represents a binary unsigned division operation.
304 class SCEVUDivExpr : public SCEV {
305 friend class ScalarEvolution;
306
307 std::array<const SCEV *, 2> Operands;
308
SCEVUDivExpr(const FoldingSetNodeIDRef ID,const SCEV * lhs,const SCEV * rhs)309 SCEVUDivExpr(const FoldingSetNodeIDRef ID, const SCEV *lhs, const SCEV *rhs)
310 : SCEV(ID, scUDivExpr, computeExpressionSize({lhs, rhs})) {
311 Operands[0] = lhs;
312 Operands[1] = rhs;
313 }
314
315 public:
getLHS()316 const SCEV *getLHS() const { return Operands[0]; }
getRHS()317 const SCEV *getRHS() const { return Operands[1]; }
getNumOperands()318 size_t getNumOperands() const { return 2; }
getOperand(unsigned i)319 const SCEV *getOperand(unsigned i) const {
320 assert((i == 0 || i == 1) && "Operand index out of range!");
321 return i == 0 ? getLHS() : getRHS();
322 }
323
operands()324 ArrayRef<const SCEV *> operands() const { return Operands; }
325
getType()326 Type *getType() const {
327 // In most cases the types of LHS and RHS will be the same, but in some
328 // crazy cases one or the other may be a pointer. ScalarEvolution doesn't
329 // depend on the type for correctness, but handling types carefully can
330 // avoid extra casts in the SCEVExpander. The LHS is more likely to be
331 // a pointer type than the RHS, so use the RHS' type here.
332 return getRHS()->getType();
333 }
334
335 /// Methods for support type inquiry through isa, cast, and dyn_cast:
classof(const SCEV * S)336 static bool classof(const SCEV *S) { return S->getSCEVType() == scUDivExpr; }
337 };
338
339 /// This node represents a polynomial recurrence on the trip count
340 /// of the specified loop. This is the primary focus of the
341 /// ScalarEvolution framework; all the other SCEV subclasses are
342 /// mostly just supporting infrastructure to allow SCEVAddRecExpr
343 /// expressions to be created and analyzed.
344 ///
345 /// All operands of an AddRec are required to be loop invariant.
346 ///
347 class SCEVAddRecExpr : public SCEVNAryExpr {
348 friend class ScalarEvolution;
349
350 const Loop *L;
351
SCEVAddRecExpr(const FoldingSetNodeIDRef ID,const SCEV * const * O,size_t N,const Loop * l)352 SCEVAddRecExpr(const FoldingSetNodeIDRef ID, const SCEV *const *O, size_t N,
353 const Loop *l)
354 : SCEVNAryExpr(ID, scAddRecExpr, O, N), L(l) {}
355
356 public:
getType()357 Type *getType() const { return getStart()->getType(); }
getStart()358 const SCEV *getStart() const { return Operands[0]; }
getLoop()359 const Loop *getLoop() const { return L; }
360
361 /// Constructs and returns the recurrence indicating how much this
362 /// expression steps by. If this is a polynomial of degree N, it
363 /// returns a chrec of degree N-1. We cannot determine whether
364 /// the step recurrence has self-wraparound.
getStepRecurrence(ScalarEvolution & SE)365 const SCEV *getStepRecurrence(ScalarEvolution &SE) const {
366 if (isAffine())
367 return getOperand(1);
368 return SE.getAddRecExpr(
369 SmallVector<const SCEV *, 3>(operands().drop_front()), getLoop(),
370 FlagAnyWrap);
371 }
372
373 /// Return true if this represents an expression A + B*x where A
374 /// and B are loop invariant values.
isAffine()375 bool isAffine() const {
376 // We know that the start value is invariant. This expression is thus
377 // affine iff the step is also invariant.
378 return getNumOperands() == 2;
379 }
380
381 /// Return true if this represents an expression A + B*x + C*x^2
382 /// where A, B and C are loop invariant values. This corresponds
383 /// to an addrec of the form {L,+,M,+,N}
isQuadratic()384 bool isQuadratic() const { return getNumOperands() == 3; }
385
386 /// Set flags for a recurrence without clearing any previously set flags.
387 /// For AddRec, either NUW or NSW implies NW. Keep track of this fact here
388 /// to make it easier to propagate flags.
setNoWrapFlags(NoWrapFlags Flags)389 void setNoWrapFlags(NoWrapFlags Flags) {
390 if (Flags & (FlagNUW | FlagNSW))
391 Flags = ScalarEvolution::setFlags(Flags, FlagNW);
392 SubclassData |= Flags;
393 }
394
395 /// Return the value of this chain of recurrences at the specified
396 /// iteration number.
397 const SCEV *evaluateAtIteration(const SCEV *It, ScalarEvolution &SE) const;
398
399 /// Return the value of this chain of recurrences at the specified iteration
400 /// number. Takes an explicit list of operands to represent an AddRec.
401 static const SCEV *evaluateAtIteration(ArrayRef<const SCEV *> Operands,
402 const SCEV *It, ScalarEvolution &SE);
403
404 /// Return the number of iterations of this loop that produce
405 /// values in the specified constant range. Another way of
406 /// looking at this is that it returns the first iteration number
407 /// where the value is not in the condition, thus computing the
408 /// exit count. If the iteration count can't be computed, an
409 /// instance of SCEVCouldNotCompute is returned.
410 const SCEV *getNumIterationsInRange(const ConstantRange &Range,
411 ScalarEvolution &SE) const;
412
413 /// Return an expression representing the value of this expression
414 /// one iteration of the loop ahead.
415 const SCEVAddRecExpr *getPostIncExpr(ScalarEvolution &SE) const;
416
417 /// Methods for support type inquiry through isa, cast, and dyn_cast:
classof(const SCEV * S)418 static bool classof(const SCEV *S) {
419 return S->getSCEVType() == scAddRecExpr;
420 }
421 };
422
423 /// This node is the base class min/max selections.
424 class SCEVMinMaxExpr : public SCEVCommutativeExpr {
425 friend class ScalarEvolution;
426
isMinMaxType(enum SCEVTypes T)427 static bool isMinMaxType(enum SCEVTypes T) {
428 return T == scSMaxExpr || T == scUMaxExpr || T == scSMinExpr ||
429 T == scUMinExpr;
430 }
431
432 protected:
433 /// Note: Constructing subclasses via this constructor is allowed
SCEVMinMaxExpr(const FoldingSetNodeIDRef ID,enum SCEVTypes T,const SCEV * const * O,size_t N)434 SCEVMinMaxExpr(const FoldingSetNodeIDRef ID, enum SCEVTypes T,
435 const SCEV *const *O, size_t N)
436 : SCEVCommutativeExpr(ID, T, O, N) {
437 assert(isMinMaxType(T));
438 // Min and max never overflow
439 setNoWrapFlags((NoWrapFlags)(FlagNUW | FlagNSW));
440 }
441
442 public:
getType()443 Type *getType() const { return getOperand(0)->getType(); }
444
classof(const SCEV * S)445 static bool classof(const SCEV *S) { return isMinMaxType(S->getSCEVType()); }
446
negate(enum SCEVTypes T)447 static enum SCEVTypes negate(enum SCEVTypes T) {
448 switch (T) {
449 case scSMaxExpr:
450 return scSMinExpr;
451 case scSMinExpr:
452 return scSMaxExpr;
453 case scUMaxExpr:
454 return scUMinExpr;
455 case scUMinExpr:
456 return scUMaxExpr;
457 default:
458 llvm_unreachable("Not a min or max SCEV type!");
459 }
460 }
461 };
462
463 /// This class represents a signed maximum selection.
464 class SCEVSMaxExpr : public SCEVMinMaxExpr {
465 friend class ScalarEvolution;
466
SCEVSMaxExpr(const FoldingSetNodeIDRef ID,const SCEV * const * O,size_t N)467 SCEVSMaxExpr(const FoldingSetNodeIDRef ID, const SCEV *const *O, size_t N)
468 : SCEVMinMaxExpr(ID, scSMaxExpr, O, N) {}
469
470 public:
471 /// Methods for support type inquiry through isa, cast, and dyn_cast:
classof(const SCEV * S)472 static bool classof(const SCEV *S) { return S->getSCEVType() == scSMaxExpr; }
473 };
474
475 /// This class represents an unsigned maximum selection.
476 class SCEVUMaxExpr : public SCEVMinMaxExpr {
477 friend class ScalarEvolution;
478
SCEVUMaxExpr(const FoldingSetNodeIDRef ID,const SCEV * const * O,size_t N)479 SCEVUMaxExpr(const FoldingSetNodeIDRef ID, const SCEV *const *O, size_t N)
480 : SCEVMinMaxExpr(ID, scUMaxExpr, O, N) {}
481
482 public:
483 /// Methods for support type inquiry through isa, cast, and dyn_cast:
classof(const SCEV * S)484 static bool classof(const SCEV *S) { return S->getSCEVType() == scUMaxExpr; }
485 };
486
487 /// This class represents a signed minimum selection.
488 class SCEVSMinExpr : public SCEVMinMaxExpr {
489 friend class ScalarEvolution;
490
SCEVSMinExpr(const FoldingSetNodeIDRef ID,const SCEV * const * O,size_t N)491 SCEVSMinExpr(const FoldingSetNodeIDRef ID, const SCEV *const *O, size_t N)
492 : SCEVMinMaxExpr(ID, scSMinExpr, O, N) {}
493
494 public:
495 /// Methods for support type inquiry through isa, cast, and dyn_cast:
classof(const SCEV * S)496 static bool classof(const SCEV *S) { return S->getSCEVType() == scSMinExpr; }
497 };
498
499 /// This class represents an unsigned minimum selection.
500 class SCEVUMinExpr : public SCEVMinMaxExpr {
501 friend class ScalarEvolution;
502
SCEVUMinExpr(const FoldingSetNodeIDRef ID,const SCEV * const * O,size_t N)503 SCEVUMinExpr(const FoldingSetNodeIDRef ID, const SCEV *const *O, size_t N)
504 : SCEVMinMaxExpr(ID, scUMinExpr, O, N) {}
505
506 public:
507 /// Methods for support type inquiry through isa, cast, and dyn_cast:
classof(const SCEV * S)508 static bool classof(const SCEV *S) { return S->getSCEVType() == scUMinExpr; }
509 };
510
511 /// This node is the base class for sequential/in-order min/max selections.
512 /// Note that their fundamental difference from SCEVMinMaxExpr's is that they
513 /// are early-returning upon reaching saturation point.
514 /// I.e. given `0 umin_seq poison`, the result will be `0`,
515 /// while the result of `0 umin poison` is `poison`.
516 class SCEVSequentialMinMaxExpr : public SCEVNAryExpr {
517 friend class ScalarEvolution;
518
isSequentialMinMaxType(enum SCEVTypes T)519 static bool isSequentialMinMaxType(enum SCEVTypes T) {
520 return T == scSequentialUMinExpr;
521 }
522
523 /// Set flags for a non-recurrence without clearing previously set flags.
setNoWrapFlags(NoWrapFlags Flags)524 void setNoWrapFlags(NoWrapFlags Flags) { SubclassData |= Flags; }
525
526 protected:
527 /// Note: Constructing subclasses via this constructor is allowed
SCEVSequentialMinMaxExpr(const FoldingSetNodeIDRef ID,enum SCEVTypes T,const SCEV * const * O,size_t N)528 SCEVSequentialMinMaxExpr(const FoldingSetNodeIDRef ID, enum SCEVTypes T,
529 const SCEV *const *O, size_t N)
530 : SCEVNAryExpr(ID, T, O, N) {
531 assert(isSequentialMinMaxType(T));
532 // Min and max never overflow
533 setNoWrapFlags((NoWrapFlags)(FlagNUW | FlagNSW));
534 }
535
536 public:
getType()537 Type *getType() const { return getOperand(0)->getType(); }
538
getEquivalentNonSequentialSCEVType(SCEVTypes Ty)539 static SCEVTypes getEquivalentNonSequentialSCEVType(SCEVTypes Ty) {
540 assert(isSequentialMinMaxType(Ty));
541 switch (Ty) {
542 case scSequentialUMinExpr:
543 return scUMinExpr;
544 default:
545 llvm_unreachable("Not a sequential min/max type.");
546 }
547 }
548
getEquivalentNonSequentialSCEVType()549 SCEVTypes getEquivalentNonSequentialSCEVType() const {
550 return getEquivalentNonSequentialSCEVType(getSCEVType());
551 }
552
classof(const SCEV * S)553 static bool classof(const SCEV *S) {
554 return isSequentialMinMaxType(S->getSCEVType());
555 }
556 };
557
558 /// This class represents a sequential/in-order unsigned minimum selection.
559 class SCEVSequentialUMinExpr : public SCEVSequentialMinMaxExpr {
560 friend class ScalarEvolution;
561
SCEVSequentialUMinExpr(const FoldingSetNodeIDRef ID,const SCEV * const * O,size_t N)562 SCEVSequentialUMinExpr(const FoldingSetNodeIDRef ID, const SCEV *const *O,
563 size_t N)
564 : SCEVSequentialMinMaxExpr(ID, scSequentialUMinExpr, O, N) {}
565
566 public:
567 /// Methods for support type inquiry through isa, cast, and dyn_cast:
classof(const SCEV * S)568 static bool classof(const SCEV *S) {
569 return S->getSCEVType() == scSequentialUMinExpr;
570 }
571 };
572
573 /// This means that we are dealing with an entirely unknown SCEV
574 /// value, and only represent it as its LLVM Value. This is the
575 /// "bottom" value for the analysis.
576 class SCEVUnknown final : public SCEV, private CallbackVH {
577 friend class ScalarEvolution;
578
579 /// The parent ScalarEvolution value. This is used to update the
580 /// parent's maps when the value associated with a SCEVUnknown is
581 /// deleted or RAUW'd.
582 ScalarEvolution *SE;
583
584 /// The next pointer in the linked list of all SCEVUnknown
585 /// instances owned by a ScalarEvolution.
586 SCEVUnknown *Next;
587
SCEVUnknown(const FoldingSetNodeIDRef ID,Value * V,ScalarEvolution * se,SCEVUnknown * next)588 SCEVUnknown(const FoldingSetNodeIDRef ID, Value *V, ScalarEvolution *se,
589 SCEVUnknown *next)
590 : SCEV(ID, scUnknown, 1), CallbackVH(V), SE(se), Next(next) {}
591
592 // Implement CallbackVH.
593 void deleted() override;
594 void allUsesReplacedWith(Value *New) override;
595
596 public:
getValue()597 Value *getValue() const { return getValPtr(); }
598
getType()599 Type *getType() const { return getValPtr()->getType(); }
600
601 /// Methods for support type inquiry through isa, cast, and dyn_cast:
classof(const SCEV * S)602 static bool classof(const SCEV *S) { return S->getSCEVType() == scUnknown; }
603 };
604
605 /// This class defines a simple visitor class that may be used for
606 /// various SCEV analysis purposes.
607 template <typename SC, typename RetVal = void> struct SCEVVisitor {
visitSCEVVisitor608 RetVal visit(const SCEV *S) {
609 switch (S->getSCEVType()) {
610 case scConstant:
611 return ((SC *)this)->visitConstant((const SCEVConstant *)S);
612 case scVScale:
613 return ((SC *)this)->visitVScale((const SCEVVScale *)S);
614 case scPtrToInt:
615 return ((SC *)this)->visitPtrToIntExpr((const SCEVPtrToIntExpr *)S);
616 case scTruncate:
617 return ((SC *)this)->visitTruncateExpr((const SCEVTruncateExpr *)S);
618 case scZeroExtend:
619 return ((SC *)this)->visitZeroExtendExpr((const SCEVZeroExtendExpr *)S);
620 case scSignExtend:
621 return ((SC *)this)->visitSignExtendExpr((const SCEVSignExtendExpr *)S);
622 case scAddExpr:
623 return ((SC *)this)->visitAddExpr((const SCEVAddExpr *)S);
624 case scMulExpr:
625 return ((SC *)this)->visitMulExpr((const SCEVMulExpr *)S);
626 case scUDivExpr:
627 return ((SC *)this)->visitUDivExpr((const SCEVUDivExpr *)S);
628 case scAddRecExpr:
629 return ((SC *)this)->visitAddRecExpr((const SCEVAddRecExpr *)S);
630 case scSMaxExpr:
631 return ((SC *)this)->visitSMaxExpr((const SCEVSMaxExpr *)S);
632 case scUMaxExpr:
633 return ((SC *)this)->visitUMaxExpr((const SCEVUMaxExpr *)S);
634 case scSMinExpr:
635 return ((SC *)this)->visitSMinExpr((const SCEVSMinExpr *)S);
636 case scUMinExpr:
637 return ((SC *)this)->visitUMinExpr((const SCEVUMinExpr *)S);
638 case scSequentialUMinExpr:
639 return ((SC *)this)
640 ->visitSequentialUMinExpr((const SCEVSequentialUMinExpr *)S);
641 case scUnknown:
642 return ((SC *)this)->visitUnknown((const SCEVUnknown *)S);
643 case scCouldNotCompute:
644 return ((SC *)this)->visitCouldNotCompute((const SCEVCouldNotCompute *)S);
645 }
646 llvm_unreachable("Unknown SCEV kind!");
647 }
648
visitCouldNotComputeSCEVVisitor649 RetVal visitCouldNotCompute(const SCEVCouldNotCompute *S) {
650 llvm_unreachable("Invalid use of SCEVCouldNotCompute!");
651 }
652 };
653
654 /// Visit all nodes in the expression tree using worklist traversal.
655 ///
656 /// Visitor implements:
657 /// // return true to follow this node.
658 /// bool follow(const SCEV *S);
659 /// // return true to terminate the search.
660 /// bool isDone();
661 template <typename SV> class SCEVTraversal {
662 SV &Visitor;
663 SmallVector<const SCEV *, 8> Worklist;
664 SmallPtrSet<const SCEV *, 8> Visited;
665
push(const SCEV * S)666 void push(const SCEV *S) {
667 if (Visited.insert(S).second && Visitor.follow(S))
668 Worklist.push_back(S);
669 }
670
671 public:
SCEVTraversal(SV & V)672 SCEVTraversal(SV &V) : Visitor(V) {}
673
visitAll(const SCEV * Root)674 void visitAll(const SCEV *Root) {
675 push(Root);
676 while (!Worklist.empty() && !Visitor.isDone()) {
677 const SCEV *S = Worklist.pop_back_val();
678
679 switch (S->getSCEVType()) {
680 case scConstant:
681 case scVScale:
682 case scUnknown:
683 continue;
684 case scPtrToInt:
685 case scTruncate:
686 case scZeroExtend:
687 case scSignExtend:
688 case scAddExpr:
689 case scMulExpr:
690 case scUDivExpr:
691 case scSMaxExpr:
692 case scUMaxExpr:
693 case scSMinExpr:
694 case scUMinExpr:
695 case scSequentialUMinExpr:
696 case scAddRecExpr:
697 for (const auto *Op : S->operands()) {
698 push(Op);
699 if (Visitor.isDone())
700 break;
701 }
702 continue;
703 case scCouldNotCompute:
704 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
705 }
706 llvm_unreachable("Unknown SCEV kind!");
707 }
708 }
709 };
710
711 /// Use SCEVTraversal to visit all nodes in the given expression tree.
visitAll(const SCEV * Root,SV & Visitor)712 template <typename SV> void visitAll(const SCEV *Root, SV &Visitor) {
713 SCEVTraversal<SV> T(Visitor);
714 T.visitAll(Root);
715 }
716
717 /// Return true if any node in \p Root satisfies the predicate \p Pred.
718 template <typename PredTy>
SCEVExprContains(const SCEV * Root,PredTy Pred)719 bool SCEVExprContains(const SCEV *Root, PredTy Pred) {
720 struct FindClosure {
721 bool Found = false;
722 PredTy Pred;
723
724 FindClosure(PredTy Pred) : Pred(Pred) {}
725
726 bool follow(const SCEV *S) {
727 if (!Pred(S))
728 return true;
729
730 Found = true;
731 return false;
732 }
733
734 bool isDone() const { return Found; }
735 };
736
737 FindClosure FC(Pred);
738 visitAll(Root, FC);
739 return FC.Found;
740 }
741
742 /// This visitor recursively visits a SCEV expression and re-writes it.
743 /// The result from each visit is cached, so it will return the same
744 /// SCEV for the same input.
745 template <typename SC>
746 class SCEVRewriteVisitor : public SCEVVisitor<SC, const SCEV *> {
747 protected:
748 ScalarEvolution &SE;
749 // Memoize the result of each visit so that we only compute once for
750 // the same input SCEV. This is to avoid redundant computations when
751 // a SCEV is referenced by multiple SCEVs. Without memoization, this
752 // visit algorithm would have exponential time complexity in the worst
753 // case, causing the compiler to hang on certain tests.
754 SmallDenseMap<const SCEV *, const SCEV *> RewriteResults;
755
756 public:
/// Bind the rewriter to the ScalarEvolution instance used to build any
/// replacement expressions.
SCEVRewriteVisitor(ScalarEvolution &SE)
    : SE(SE) {
}
758
/// Rewrite \p S, consulting the memoization map first so each distinct
/// input is rewritten at most once.
const SCEV *visit(const SCEV *S) {
  auto Cached = RewriteResults.find(S);
  if (Cached != RewriteResults.end())
    return Cached->second;

  // Not cached yet: dispatch through the base visitor, then record the
  // result. The lookup above is not reused because the dispatch may itself
  // have added entries to the map.
  const SCEV *Rewritten = SCEVVisitor<SC, const SCEV *>::visit(S);
  auto Insertion = RewriteResults.try_emplace(S, Rewritten);
  assert(Insertion.second && "Should insert a new entry");
  return Insertion.first->second;
}
768
visitConstant(const SCEVConstant * Constant)769 const SCEV *visitConstant(const SCEVConstant *Constant) { return Constant; }
770
visitVScale(const SCEVVScale * VScale)771 const SCEV *visitVScale(const SCEVVScale *VScale) { return VScale; }
772
visitPtrToIntExpr(const SCEVPtrToIntExpr * Expr)773 const SCEV *visitPtrToIntExpr(const SCEVPtrToIntExpr *Expr) {
774 const SCEV *Operand = ((SC *)this)->visit(Expr->getOperand());
775 return Operand == Expr->getOperand()
776 ? Expr
777 : SE.getPtrToIntExpr(Operand, Expr->getType());
778 }
779
visitTruncateExpr(const SCEVTruncateExpr * Expr)780 const SCEV *visitTruncateExpr(const SCEVTruncateExpr *Expr) {
781 const SCEV *Operand = ((SC *)this)->visit(Expr->getOperand());
782 return Operand == Expr->getOperand()
783 ? Expr
784 : SE.getTruncateExpr(Operand, Expr->getType());
785 }
786
visitZeroExtendExpr(const SCEVZeroExtendExpr * Expr)787 const SCEV *visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) {
788 const SCEV *Operand = ((SC *)this)->visit(Expr->getOperand());
789 return Operand == Expr->getOperand()
790 ? Expr
791 : SE.getZeroExtendExpr(Operand, Expr->getType());
792 }
793
visitSignExtendExpr(const SCEVSignExtendExpr * Expr)794 const SCEV *visitSignExtendExpr(const SCEVSignExtendExpr *Expr) {
795 const SCEV *Operand = ((SC *)this)->visit(Expr->getOperand());
796 return Operand == Expr->getOperand()
797 ? Expr
798 : SE.getSignExtendExpr(Operand, Expr->getType());
799 }
800
visitAddExpr(const SCEVAddExpr * Expr)801 const SCEV *visitAddExpr(const SCEVAddExpr *Expr) {
802 SmallVector<const SCEV *, 2> Operands;
803 bool Changed = false;
804 for (const auto *Op : Expr->operands()) {
805 Operands.push_back(((SC *)this)->visit(Op));
806 Changed |= Op != Operands.back();
807 }
808 return !Changed ? Expr : SE.getAddExpr(Operands);
809 }
810
visitMulExpr(const SCEVMulExpr * Expr)811 const SCEV *visitMulExpr(const SCEVMulExpr *Expr) {
812 SmallVector<const SCEV *, 2> Operands;
813 bool Changed = false;
814 for (const auto *Op : Expr->operands()) {
815 Operands.push_back(((SC *)this)->visit(Op));
816 Changed |= Op != Operands.back();
817 }
818 return !Changed ? Expr : SE.getMulExpr(Operands);
819 }
820
visitUDivExpr(const SCEVUDivExpr * Expr)821 const SCEV *visitUDivExpr(const SCEVUDivExpr *Expr) {
822 auto *LHS = ((SC *)this)->visit(Expr->getLHS());
823 auto *RHS = ((SC *)this)->visit(Expr->getRHS());
824 bool Changed = LHS != Expr->getLHS() || RHS != Expr->getRHS();
825 return !Changed ? Expr : SE.getUDivExpr(LHS, RHS);
826 }
827
visitAddRecExpr(const SCEVAddRecExpr * Expr)828 const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) {
829 SmallVector<const SCEV *, 2> Operands;
830 bool Changed = false;
831 for (const auto *Op : Expr->operands()) {
832 Operands.push_back(((SC *)this)->visit(Op));
833 Changed |= Op != Operands.back();
834 }
835 return !Changed ? Expr
836 : SE.getAddRecExpr(Operands, Expr->getLoop(),
837 Expr->getNoWrapFlags());
838 }
839
visitSMaxExpr(const SCEVSMaxExpr * Expr)840 const SCEV *visitSMaxExpr(const SCEVSMaxExpr *Expr) {
841 SmallVector<const SCEV *, 2> Operands;
842 bool Changed = false;
843 for (const auto *Op : Expr->operands()) {
844 Operands.push_back(((SC *)this)->visit(Op));
845 Changed |= Op != Operands.back();
846 }
847 return !Changed ? Expr : SE.getSMaxExpr(Operands);
848 }
849
visitUMaxExpr(const SCEVUMaxExpr * Expr)850 const SCEV *visitUMaxExpr(const SCEVUMaxExpr *Expr) {
851 SmallVector<const SCEV *, 2> Operands;
852 bool Changed = false;
853 for (const auto *Op : Expr->operands()) {
854 Operands.push_back(((SC *)this)->visit(Op));
855 Changed |= Op != Operands.back();
856 }
857 return !Changed ? Expr : SE.getUMaxExpr(Operands);
858 }
859
visitSMinExpr(const SCEVSMinExpr * Expr)860 const SCEV *visitSMinExpr(const SCEVSMinExpr *Expr) {
861 SmallVector<const SCEV *, 2> Operands;
862 bool Changed = false;
863 for (const auto *Op : Expr->operands()) {
864 Operands.push_back(((SC *)this)->visit(Op));
865 Changed |= Op != Operands.back();
866 }
867 return !Changed ? Expr : SE.getSMinExpr(Operands);
868 }
869
visitUMinExpr(const SCEVUMinExpr * Expr)870 const SCEV *visitUMinExpr(const SCEVUMinExpr *Expr) {
871 SmallVector<const SCEV *, 2> Operands;
872 bool Changed = false;
873 for (const auto *Op : Expr->operands()) {
874 Operands.push_back(((SC *)this)->visit(Op));
875 Changed |= Op != Operands.back();
876 }
877 return !Changed ? Expr : SE.getUMinExpr(Operands);
878 }
879
visitSequentialUMinExpr(const SCEVSequentialUMinExpr * Expr)880 const SCEV *visitSequentialUMinExpr(const SCEVSequentialUMinExpr *Expr) {
881 SmallVector<const SCEV *, 2> Operands;
882 bool Changed = false;
883 for (const auto *Op : Expr->operands()) {
884 Operands.push_back(((SC *)this)->visit(Op));
885 Changed |= Op != Operands.back();
886 }
887 return !Changed ? Expr : SE.getUMinExpr(Operands, /*Sequential=*/true);
888 }
889
visitUnknown(const SCEVUnknown * Expr)890 const SCEV *visitUnknown(const SCEVUnknown *Expr) { return Expr; }
891
visitCouldNotCompute(const SCEVCouldNotCompute * Expr)892 const SCEV *visitCouldNotCompute(const SCEVCouldNotCompute *Expr) {
893 return Expr;
894 }
895 };
896
897 using ValueToValueMap = DenseMap<const Value *, Value *>;
898 using ValueToSCEVMapTy = DenseMap<const Value *, const SCEV *>;
899
900 /// The SCEVParameterRewriter takes a scalar evolution expression and updates
901 /// the SCEVUnknown components following the Map (Value -> SCEV).
902 class SCEVParameterRewriter : public SCEVRewriteVisitor<SCEVParameterRewriter> {
903 public:
rewrite(const SCEV * Scev,ScalarEvolution & SE,ValueToSCEVMapTy & Map)904 static const SCEV *rewrite(const SCEV *Scev, ScalarEvolution &SE,
905 ValueToSCEVMapTy &Map) {
906 SCEVParameterRewriter Rewriter(SE, Map);
907 return Rewriter.visit(Scev);
908 }
909
SCEVParameterRewriter(ScalarEvolution & SE,ValueToSCEVMapTy & M)910 SCEVParameterRewriter(ScalarEvolution &SE, ValueToSCEVMapTy &M)
911 : SCEVRewriteVisitor(SE), Map(M) {}
912
visitUnknown(const SCEVUnknown * Expr)913 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
914 auto I = Map.find(Expr->getValue());
915 if (I == Map.end())
916 return Expr;
917 return I->second;
918 }
919
920 private:
921 ValueToSCEVMapTy ⤅
922 };
923
924 using LoopToScevMapT = DenseMap<const Loop *, const SCEV *>;
925
926 /// The SCEVLoopAddRecRewriter takes a scalar evolution expression and applies
927 /// the Map (Loop -> SCEV) to all AddRecExprs.
928 class SCEVLoopAddRecRewriter
929 : public SCEVRewriteVisitor<SCEVLoopAddRecRewriter> {
930 public:
SCEVLoopAddRecRewriter(ScalarEvolution & SE,LoopToScevMapT & M)931 SCEVLoopAddRecRewriter(ScalarEvolution &SE, LoopToScevMapT &M)
932 : SCEVRewriteVisitor(SE), Map(M) {}
933
rewrite(const SCEV * Scev,LoopToScevMapT & Map,ScalarEvolution & SE)934 static const SCEV *rewrite(const SCEV *Scev, LoopToScevMapT &Map,
935 ScalarEvolution &SE) {
936 SCEVLoopAddRecRewriter Rewriter(SE, Map);
937 return Rewriter.visit(Scev);
938 }
939
visitAddRecExpr(const SCEVAddRecExpr * Expr)940 const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) {
941 SmallVector<const SCEV *, 2> Operands;
942 for (const SCEV *Op : Expr->operands())
943 Operands.push_back(visit(Op));
944
945 const Loop *L = Expr->getLoop();
946 if (0 == Map.count(L))
947 return SE.getAddRecExpr(Operands, L, Expr->getNoWrapFlags());
948
949 return SCEVAddRecExpr::evaluateAtIteration(Operands, Map[L], SE);
950 }
951
952 private:
953 LoopToScevMapT ⤅
954 };
955
956 } // end namespace llvm
957
958 #endif // LLVM_ANALYSIS_SCALAREVOLUTIONEXPRESSIONS_H
959