1 //===- llvm/Analysis/ScalarEvolutionExpressions.h - SCEV Exprs --*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the classes used to represent and build scalar expressions.
10 //
11 //===----------------------------------------------------------------------===//
12
13 #ifndef LLVM_ANALYSIS_SCALAREVOLUTIONEXPRESSIONS_H
14 #define LLVM_ANALYSIS_SCALAREVOLUTIONEXPRESSIONS_H
15
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/iterator_range.h"
#include "llvm/Analysis/ScalarEvolution.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/ValueHandle.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/ErrorHandling.h"
#include <array>
#include <cassert>
#include <cstddef>
27
28 namespace llvm {
29
30 class APInt;
31 class Constant;
32 class ConstantInt;
33 class ConstantRange;
34 class Loop;
35 class Type;
36 class Value;
37
/// Discriminator for the concrete kind of a SCEV node.
///
/// The enumerators are listed in order of increasing complexity so that the
/// expression folders stay simple; their relative order is significant and
/// must not be changed.
enum SCEVTypes : unsigned short {
  // Leaf integer constant.
  scConstant = 0,

  // Unary integral casts.
  scTruncate,
  scZeroExtend,
  scSignExtend,

  // Arithmetic.
  scAddExpr,
  scMulExpr,
  scUDivExpr,
  scAddRecExpr,

  // Min/max selections.
  scUMaxExpr,
  scSMaxExpr,
  scUMinExpr,
  scSMinExpr,
  scSequentialUMinExpr,

  // Pointer-to-integer cast.
  scPtrToInt,

  // Opaque values and the "could not compute" sentinel.
  scUnknown,
  scCouldNotCompute
};
58
59 /// This class represents a constant integer value.
60 class SCEVConstant : public SCEV {
61 friend class ScalarEvolution;
62
63 ConstantInt *V;
64
SCEVConstant(const FoldingSetNodeIDRef ID,ConstantInt * v)65 SCEVConstant(const FoldingSetNodeIDRef ID, ConstantInt *v)
66 : SCEV(ID, scConstant, 1), V(v) {}
67
68 public:
getValue()69 ConstantInt *getValue() const { return V; }
getAPInt()70 const APInt &getAPInt() const { return getValue()->getValue(); }
71
getType()72 Type *getType() const { return V->getType(); }
73
74 /// Methods for support type inquiry through isa, cast, and dyn_cast:
classof(const SCEV * S)75 static bool classof(const SCEV *S) { return S->getSCEVType() == scConstant; }
76 };
77
computeExpressionSize(ArrayRef<const SCEV * > Args)78 inline unsigned short computeExpressionSize(ArrayRef<const SCEV *> Args) {
79 APInt Size(16, 1);
80 for (const auto *Arg : Args)
81 Size = Size.uadd_sat(APInt(16, Arg->getExpressionSize()));
82 return (unsigned short)Size.getZExtValue();
83 }
84
85 /// This is the base class for unary cast operator classes.
86 class SCEVCastExpr : public SCEV {
87 protected:
88 const SCEV *Op;
89 Type *Ty;
90
91 SCEVCastExpr(const FoldingSetNodeIDRef ID, SCEVTypes SCEVTy, const SCEV *op,
92 Type *ty);
93
94 public:
getOperand()95 const SCEV *getOperand() const { return Op; }
getOperand(unsigned i)96 const SCEV *getOperand(unsigned i) const {
97 assert(i == 0 && "Operand index out of range!");
98 return Op;
99 }
operands()100 ArrayRef<const SCEV *> operands() const { return Op; }
getNumOperands()101 size_t getNumOperands() const { return 1; }
getType()102 Type *getType() const { return Ty; }
103
104 /// Methods for support type inquiry through isa, cast, and dyn_cast:
classof(const SCEV * S)105 static bool classof(const SCEV *S) {
106 return S->getSCEVType() == scPtrToInt || S->getSCEVType() == scTruncate ||
107 S->getSCEVType() == scZeroExtend || S->getSCEVType() == scSignExtend;
108 }
109 };
110
111 /// This class represents a cast from a pointer to a pointer-sized integer
112 /// value.
113 class SCEVPtrToIntExpr : public SCEVCastExpr {
114 friend class ScalarEvolution;
115
116 SCEVPtrToIntExpr(const FoldingSetNodeIDRef ID, const SCEV *Op, Type *ITy);
117
118 public:
119 /// Methods for support type inquiry through isa, cast, and dyn_cast:
classof(const SCEV * S)120 static bool classof(const SCEV *S) { return S->getSCEVType() == scPtrToInt; }
121 };
122
123 /// This is the base class for unary integral cast operator classes.
124 class SCEVIntegralCastExpr : public SCEVCastExpr {
125 protected:
126 SCEVIntegralCastExpr(const FoldingSetNodeIDRef ID, SCEVTypes SCEVTy,
127 const SCEV *op, Type *ty);
128
129 public:
130 /// Methods for support type inquiry through isa, cast, and dyn_cast:
classof(const SCEV * S)131 static bool classof(const SCEV *S) {
132 return S->getSCEVType() == scTruncate || S->getSCEVType() == scZeroExtend ||
133 S->getSCEVType() == scSignExtend;
134 }
135 };
136
137 /// This class represents a truncation of an integer value to a
138 /// smaller integer value.
139 class SCEVTruncateExpr : public SCEVIntegralCastExpr {
140 friend class ScalarEvolution;
141
142 SCEVTruncateExpr(const FoldingSetNodeIDRef ID, const SCEV *op, Type *ty);
143
144 public:
145 /// Methods for support type inquiry through isa, cast, and dyn_cast:
classof(const SCEV * S)146 static bool classof(const SCEV *S) { return S->getSCEVType() == scTruncate; }
147 };
148
149 /// This class represents a zero extension of a small integer value
150 /// to a larger integer value.
151 class SCEVZeroExtendExpr : public SCEVIntegralCastExpr {
152 friend class ScalarEvolution;
153
154 SCEVZeroExtendExpr(const FoldingSetNodeIDRef ID, const SCEV *op, Type *ty);
155
156 public:
157 /// Methods for support type inquiry through isa, cast, and dyn_cast:
classof(const SCEV * S)158 static bool classof(const SCEV *S) {
159 return S->getSCEVType() == scZeroExtend;
160 }
161 };
162
163 /// This class represents a sign extension of a small integer value
164 /// to a larger integer value.
165 class SCEVSignExtendExpr : public SCEVIntegralCastExpr {
166 friend class ScalarEvolution;
167
168 SCEVSignExtendExpr(const FoldingSetNodeIDRef ID, const SCEV *op, Type *ty);
169
170 public:
171 /// Methods for support type inquiry through isa, cast, and dyn_cast:
classof(const SCEV * S)172 static bool classof(const SCEV *S) {
173 return S->getSCEVType() == scSignExtend;
174 }
175 };
176
177 /// This node is a base class providing common functionality for
178 /// n'ary operators.
179 class SCEVNAryExpr : public SCEV {
180 protected:
181 // Since SCEVs are immutable, ScalarEvolution allocates operand
182 // arrays with its SCEVAllocator, so this class just needs a simple
183 // pointer rather than a more elaborate vector-like data structure.
184 // This also avoids the need for a non-trivial destructor.
185 const SCEV *const *Operands;
186 size_t NumOperands;
187
SCEVNAryExpr(const FoldingSetNodeIDRef ID,enum SCEVTypes T,const SCEV * const * O,size_t N)188 SCEVNAryExpr(const FoldingSetNodeIDRef ID, enum SCEVTypes T,
189 const SCEV *const *O, size_t N)
190 : SCEV(ID, T, computeExpressionSize(ArrayRef(O, N))), Operands(O),
191 NumOperands(N) {}
192
193 public:
getNumOperands()194 size_t getNumOperands() const { return NumOperands; }
195
getOperand(unsigned i)196 const SCEV *getOperand(unsigned i) const {
197 assert(i < NumOperands && "Operand index out of range!");
198 return Operands[i];
199 }
200
operands()201 ArrayRef<const SCEV *> operands() const {
202 return ArrayRef(Operands, NumOperands);
203 }
204
205 NoWrapFlags getNoWrapFlags(NoWrapFlags Mask = NoWrapMask) const {
206 return (NoWrapFlags)(SubclassData & Mask);
207 }
208
hasNoUnsignedWrap()209 bool hasNoUnsignedWrap() const {
210 return getNoWrapFlags(FlagNUW) != FlagAnyWrap;
211 }
212
hasNoSignedWrap()213 bool hasNoSignedWrap() const {
214 return getNoWrapFlags(FlagNSW) != FlagAnyWrap;
215 }
216
hasNoSelfWrap()217 bool hasNoSelfWrap() const { return getNoWrapFlags(FlagNW) != FlagAnyWrap; }
218
219 /// Methods for support type inquiry through isa, cast, and dyn_cast:
classof(const SCEV * S)220 static bool classof(const SCEV *S) {
221 return S->getSCEVType() == scAddExpr || S->getSCEVType() == scMulExpr ||
222 S->getSCEVType() == scSMaxExpr || S->getSCEVType() == scUMaxExpr ||
223 S->getSCEVType() == scSMinExpr || S->getSCEVType() == scUMinExpr ||
224 S->getSCEVType() == scSequentialUMinExpr ||
225 S->getSCEVType() == scAddRecExpr;
226 }
227 };
228
229 /// This node is the base class for n'ary commutative operators.
230 class SCEVCommutativeExpr : public SCEVNAryExpr {
231 protected:
SCEVCommutativeExpr(const FoldingSetNodeIDRef ID,enum SCEVTypes T,const SCEV * const * O,size_t N)232 SCEVCommutativeExpr(const FoldingSetNodeIDRef ID, enum SCEVTypes T,
233 const SCEV *const *O, size_t N)
234 : SCEVNAryExpr(ID, T, O, N) {}
235
236 public:
237 /// Methods for support type inquiry through isa, cast, and dyn_cast:
classof(const SCEV * S)238 static bool classof(const SCEV *S) {
239 return S->getSCEVType() == scAddExpr || S->getSCEVType() == scMulExpr ||
240 S->getSCEVType() == scSMaxExpr || S->getSCEVType() == scUMaxExpr ||
241 S->getSCEVType() == scSMinExpr || S->getSCEVType() == scUMinExpr;
242 }
243
244 /// Set flags for a non-recurrence without clearing previously set flags.
setNoWrapFlags(NoWrapFlags Flags)245 void setNoWrapFlags(NoWrapFlags Flags) { SubclassData |= Flags; }
246 };
247
248 /// This node represents an addition of some number of SCEVs.
249 class SCEVAddExpr : public SCEVCommutativeExpr {
250 friend class ScalarEvolution;
251
252 Type *Ty;
253
SCEVAddExpr(const FoldingSetNodeIDRef ID,const SCEV * const * O,size_t N)254 SCEVAddExpr(const FoldingSetNodeIDRef ID, const SCEV *const *O, size_t N)
255 : SCEVCommutativeExpr(ID, scAddExpr, O, N) {
256 auto *FirstPointerTypedOp = find_if(operands(), [](const SCEV *Op) {
257 return Op->getType()->isPointerTy();
258 });
259 if (FirstPointerTypedOp != operands().end())
260 Ty = (*FirstPointerTypedOp)->getType();
261 else
262 Ty = getOperand(0)->getType();
263 }
264
265 public:
getType()266 Type *getType() const { return Ty; }
267
268 /// Methods for support type inquiry through isa, cast, and dyn_cast:
classof(const SCEV * S)269 static bool classof(const SCEV *S) { return S->getSCEVType() == scAddExpr; }
270 };
271
272 /// This node represents multiplication of some number of SCEVs.
273 class SCEVMulExpr : public SCEVCommutativeExpr {
274 friend class ScalarEvolution;
275
SCEVMulExpr(const FoldingSetNodeIDRef ID,const SCEV * const * O,size_t N)276 SCEVMulExpr(const FoldingSetNodeIDRef ID, const SCEV *const *O, size_t N)
277 : SCEVCommutativeExpr(ID, scMulExpr, O, N) {}
278
279 public:
getType()280 Type *getType() const { return getOperand(0)->getType(); }
281
282 /// Methods for support type inquiry through isa, cast, and dyn_cast:
classof(const SCEV * S)283 static bool classof(const SCEV *S) { return S->getSCEVType() == scMulExpr; }
284 };
285
286 /// This class represents a binary unsigned division operation.
287 class SCEVUDivExpr : public SCEV {
288 friend class ScalarEvolution;
289
290 std::array<const SCEV *, 2> Operands;
291
SCEVUDivExpr(const FoldingSetNodeIDRef ID,const SCEV * lhs,const SCEV * rhs)292 SCEVUDivExpr(const FoldingSetNodeIDRef ID, const SCEV *lhs, const SCEV *rhs)
293 : SCEV(ID, scUDivExpr, computeExpressionSize({lhs, rhs})) {
294 Operands[0] = lhs;
295 Operands[1] = rhs;
296 }
297
298 public:
getLHS()299 const SCEV *getLHS() const { return Operands[0]; }
getRHS()300 const SCEV *getRHS() const { return Operands[1]; }
getNumOperands()301 size_t getNumOperands() const { return 2; }
getOperand(unsigned i)302 const SCEV *getOperand(unsigned i) const {
303 assert((i == 0 || i == 1) && "Operand index out of range!");
304 return i == 0 ? getLHS() : getRHS();
305 }
306
operands()307 ArrayRef<const SCEV *> operands() const { return Operands; }
308
getType()309 Type *getType() const {
310 // In most cases the types of LHS and RHS will be the same, but in some
311 // crazy cases one or the other may be a pointer. ScalarEvolution doesn't
312 // depend on the type for correctness, but handling types carefully can
313 // avoid extra casts in the SCEVExpander. The LHS is more likely to be
314 // a pointer type than the RHS, so use the RHS' type here.
315 return getRHS()->getType();
316 }
317
318 /// Methods for support type inquiry through isa, cast, and dyn_cast:
classof(const SCEV * S)319 static bool classof(const SCEV *S) { return S->getSCEVType() == scUDivExpr; }
320 };
321
322 /// This node represents a polynomial recurrence on the trip count
323 /// of the specified loop. This is the primary focus of the
324 /// ScalarEvolution framework; all the other SCEV subclasses are
325 /// mostly just supporting infrastructure to allow SCEVAddRecExpr
326 /// expressions to be created and analyzed.
327 ///
328 /// All operands of an AddRec are required to be loop invariant.
329 ///
330 class SCEVAddRecExpr : public SCEVNAryExpr {
331 friend class ScalarEvolution;
332
333 const Loop *L;
334
SCEVAddRecExpr(const FoldingSetNodeIDRef ID,const SCEV * const * O,size_t N,const Loop * l)335 SCEVAddRecExpr(const FoldingSetNodeIDRef ID, const SCEV *const *O, size_t N,
336 const Loop *l)
337 : SCEVNAryExpr(ID, scAddRecExpr, O, N), L(l) {}
338
339 public:
getType()340 Type *getType() const { return getStart()->getType(); }
getStart()341 const SCEV *getStart() const { return Operands[0]; }
getLoop()342 const Loop *getLoop() const { return L; }
343
344 /// Constructs and returns the recurrence indicating how much this
345 /// expression steps by. If this is a polynomial of degree N, it
346 /// returns a chrec of degree N-1. We cannot determine whether
347 /// the step recurrence has self-wraparound.
getStepRecurrence(ScalarEvolution & SE)348 const SCEV *getStepRecurrence(ScalarEvolution &SE) const {
349 if (isAffine())
350 return getOperand(1);
351 return SE.getAddRecExpr(
352 SmallVector<const SCEV *, 3>(operands().drop_front()), getLoop(),
353 FlagAnyWrap);
354 }
355
356 /// Return true if this represents an expression A + B*x where A
357 /// and B are loop invariant values.
isAffine()358 bool isAffine() const {
359 // We know that the start value is invariant. This expression is thus
360 // affine iff the step is also invariant.
361 return getNumOperands() == 2;
362 }
363
364 /// Return true if this represents an expression A + B*x + C*x^2
365 /// where A, B and C are loop invariant values. This corresponds
366 /// to an addrec of the form {L,+,M,+,N}
isQuadratic()367 bool isQuadratic() const { return getNumOperands() == 3; }
368
369 /// Set flags for a recurrence without clearing any previously set flags.
370 /// For AddRec, either NUW or NSW implies NW. Keep track of this fact here
371 /// to make it easier to propagate flags.
setNoWrapFlags(NoWrapFlags Flags)372 void setNoWrapFlags(NoWrapFlags Flags) {
373 if (Flags & (FlagNUW | FlagNSW))
374 Flags = ScalarEvolution::setFlags(Flags, FlagNW);
375 SubclassData |= Flags;
376 }
377
378 /// Return the value of this chain of recurrences at the specified
379 /// iteration number.
380 const SCEV *evaluateAtIteration(const SCEV *It, ScalarEvolution &SE) const;
381
382 /// Return the value of this chain of recurrences at the specified iteration
383 /// number. Takes an explicit list of operands to represent an AddRec.
384 static const SCEV *evaluateAtIteration(ArrayRef<const SCEV *> Operands,
385 const SCEV *It, ScalarEvolution &SE);
386
387 /// Return the number of iterations of this loop that produce
388 /// values in the specified constant range. Another way of
389 /// looking at this is that it returns the first iteration number
390 /// where the value is not in the condition, thus computing the
391 /// exit count. If the iteration count can't be computed, an
392 /// instance of SCEVCouldNotCompute is returned.
393 const SCEV *getNumIterationsInRange(const ConstantRange &Range,
394 ScalarEvolution &SE) const;
395
396 /// Return an expression representing the value of this expression
397 /// one iteration of the loop ahead.
398 const SCEVAddRecExpr *getPostIncExpr(ScalarEvolution &SE) const;
399
400 /// Methods for support type inquiry through isa, cast, and dyn_cast:
classof(const SCEV * S)401 static bool classof(const SCEV *S) {
402 return S->getSCEVType() == scAddRecExpr;
403 }
404 };
405
406 /// This node is the base class min/max selections.
407 class SCEVMinMaxExpr : public SCEVCommutativeExpr {
408 friend class ScalarEvolution;
409
isMinMaxType(enum SCEVTypes T)410 static bool isMinMaxType(enum SCEVTypes T) {
411 return T == scSMaxExpr || T == scUMaxExpr || T == scSMinExpr ||
412 T == scUMinExpr;
413 }
414
415 protected:
416 /// Note: Constructing subclasses via this constructor is allowed
SCEVMinMaxExpr(const FoldingSetNodeIDRef ID,enum SCEVTypes T,const SCEV * const * O,size_t N)417 SCEVMinMaxExpr(const FoldingSetNodeIDRef ID, enum SCEVTypes T,
418 const SCEV *const *O, size_t N)
419 : SCEVCommutativeExpr(ID, T, O, N) {
420 assert(isMinMaxType(T));
421 // Min and max never overflow
422 setNoWrapFlags((NoWrapFlags)(FlagNUW | FlagNSW));
423 }
424
425 public:
getType()426 Type *getType() const { return getOperand(0)->getType(); }
427
classof(const SCEV * S)428 static bool classof(const SCEV *S) { return isMinMaxType(S->getSCEVType()); }
429
negate(enum SCEVTypes T)430 static enum SCEVTypes negate(enum SCEVTypes T) {
431 switch (T) {
432 case scSMaxExpr:
433 return scSMinExpr;
434 case scSMinExpr:
435 return scSMaxExpr;
436 case scUMaxExpr:
437 return scUMinExpr;
438 case scUMinExpr:
439 return scUMaxExpr;
440 default:
441 llvm_unreachable("Not a min or max SCEV type!");
442 }
443 }
444 };
445
446 /// This class represents a signed maximum selection.
447 class SCEVSMaxExpr : public SCEVMinMaxExpr {
448 friend class ScalarEvolution;
449
SCEVSMaxExpr(const FoldingSetNodeIDRef ID,const SCEV * const * O,size_t N)450 SCEVSMaxExpr(const FoldingSetNodeIDRef ID, const SCEV *const *O, size_t N)
451 : SCEVMinMaxExpr(ID, scSMaxExpr, O, N) {}
452
453 public:
454 /// Methods for support type inquiry through isa, cast, and dyn_cast:
classof(const SCEV * S)455 static bool classof(const SCEV *S) { return S->getSCEVType() == scSMaxExpr; }
456 };
457
458 /// This class represents an unsigned maximum selection.
459 class SCEVUMaxExpr : public SCEVMinMaxExpr {
460 friend class ScalarEvolution;
461
SCEVUMaxExpr(const FoldingSetNodeIDRef ID,const SCEV * const * O,size_t N)462 SCEVUMaxExpr(const FoldingSetNodeIDRef ID, const SCEV *const *O, size_t N)
463 : SCEVMinMaxExpr(ID, scUMaxExpr, O, N) {}
464
465 public:
466 /// Methods for support type inquiry through isa, cast, and dyn_cast:
classof(const SCEV * S)467 static bool classof(const SCEV *S) { return S->getSCEVType() == scUMaxExpr; }
468 };
469
470 /// This class represents a signed minimum selection.
471 class SCEVSMinExpr : public SCEVMinMaxExpr {
472 friend class ScalarEvolution;
473
SCEVSMinExpr(const FoldingSetNodeIDRef ID,const SCEV * const * O,size_t N)474 SCEVSMinExpr(const FoldingSetNodeIDRef ID, const SCEV *const *O, size_t N)
475 : SCEVMinMaxExpr(ID, scSMinExpr, O, N) {}
476
477 public:
478 /// Methods for support type inquiry through isa, cast, and dyn_cast:
classof(const SCEV * S)479 static bool classof(const SCEV *S) { return S->getSCEVType() == scSMinExpr; }
480 };
481
482 /// This class represents an unsigned minimum selection.
483 class SCEVUMinExpr : public SCEVMinMaxExpr {
484 friend class ScalarEvolution;
485
SCEVUMinExpr(const FoldingSetNodeIDRef ID,const SCEV * const * O,size_t N)486 SCEVUMinExpr(const FoldingSetNodeIDRef ID, const SCEV *const *O, size_t N)
487 : SCEVMinMaxExpr(ID, scUMinExpr, O, N) {}
488
489 public:
490 /// Methods for support type inquiry through isa, cast, and dyn_cast:
classof(const SCEV * S)491 static bool classof(const SCEV *S) { return S->getSCEVType() == scUMinExpr; }
492 };
493
494 /// This node is the base class for sequential/in-order min/max selections.
495 /// Note that their fundamental difference from SCEVMinMaxExpr's is that they
496 /// are early-returning upon reaching saturation point.
497 /// I.e. given `0 umin_seq poison`, the result will be `0`,
498 /// while the result of `0 umin poison` is `poison`.
499 class SCEVSequentialMinMaxExpr : public SCEVNAryExpr {
500 friend class ScalarEvolution;
501
isSequentialMinMaxType(enum SCEVTypes T)502 static bool isSequentialMinMaxType(enum SCEVTypes T) {
503 return T == scSequentialUMinExpr;
504 }
505
506 /// Set flags for a non-recurrence without clearing previously set flags.
setNoWrapFlags(NoWrapFlags Flags)507 void setNoWrapFlags(NoWrapFlags Flags) { SubclassData |= Flags; }
508
509 protected:
510 /// Note: Constructing subclasses via this constructor is allowed
SCEVSequentialMinMaxExpr(const FoldingSetNodeIDRef ID,enum SCEVTypes T,const SCEV * const * O,size_t N)511 SCEVSequentialMinMaxExpr(const FoldingSetNodeIDRef ID, enum SCEVTypes T,
512 const SCEV *const *O, size_t N)
513 : SCEVNAryExpr(ID, T, O, N) {
514 assert(isSequentialMinMaxType(T));
515 // Min and max never overflow
516 setNoWrapFlags((NoWrapFlags)(FlagNUW | FlagNSW));
517 }
518
519 public:
getType()520 Type *getType() const { return getOperand(0)->getType(); }
521
getEquivalentNonSequentialSCEVType(SCEVTypes Ty)522 static SCEVTypes getEquivalentNonSequentialSCEVType(SCEVTypes Ty) {
523 assert(isSequentialMinMaxType(Ty));
524 switch (Ty) {
525 case scSequentialUMinExpr:
526 return scUMinExpr;
527 default:
528 llvm_unreachable("Not a sequential min/max type.");
529 }
530 }
531
getEquivalentNonSequentialSCEVType()532 SCEVTypes getEquivalentNonSequentialSCEVType() const {
533 return getEquivalentNonSequentialSCEVType(getSCEVType());
534 }
535
classof(const SCEV * S)536 static bool classof(const SCEV *S) {
537 return isSequentialMinMaxType(S->getSCEVType());
538 }
539 };
540
541 /// This class represents a sequential/in-order unsigned minimum selection.
542 class SCEVSequentialUMinExpr : public SCEVSequentialMinMaxExpr {
543 friend class ScalarEvolution;
544
SCEVSequentialUMinExpr(const FoldingSetNodeIDRef ID,const SCEV * const * O,size_t N)545 SCEVSequentialUMinExpr(const FoldingSetNodeIDRef ID, const SCEV *const *O,
546 size_t N)
547 : SCEVSequentialMinMaxExpr(ID, scSequentialUMinExpr, O, N) {}
548
549 public:
550 /// Methods for support type inquiry through isa, cast, and dyn_cast:
classof(const SCEV * S)551 static bool classof(const SCEV *S) {
552 return S->getSCEVType() == scSequentialUMinExpr;
553 }
554 };
555
556 /// This means that we are dealing with an entirely unknown SCEV
557 /// value, and only represent it as its LLVM Value. This is the
558 /// "bottom" value for the analysis.
559 class SCEVUnknown final : public SCEV, private CallbackVH {
560 friend class ScalarEvolution;
561
562 /// The parent ScalarEvolution value. This is used to update the
563 /// parent's maps when the value associated with a SCEVUnknown is
564 /// deleted or RAUW'd.
565 ScalarEvolution *SE;
566
567 /// The next pointer in the linked list of all SCEVUnknown
568 /// instances owned by a ScalarEvolution.
569 SCEVUnknown *Next;
570
SCEVUnknown(const FoldingSetNodeIDRef ID,Value * V,ScalarEvolution * se,SCEVUnknown * next)571 SCEVUnknown(const FoldingSetNodeIDRef ID, Value *V, ScalarEvolution *se,
572 SCEVUnknown *next)
573 : SCEV(ID, scUnknown, 1), CallbackVH(V), SE(se), Next(next) {}
574
575 // Implement CallbackVH.
576 void deleted() override;
577 void allUsesReplacedWith(Value *New) override;
578
579 public:
getValue()580 Value *getValue() const { return getValPtr(); }
581
582 /// @{
583 /// Test whether this is a special constant representing a type
584 /// size, alignment, or field offset in a target-independent
585 /// manner, and hasn't happened to have been folded with other
586 /// operations into something unrecognizable. This is mainly only
587 /// useful for pretty-printing and other situations where it isn't
588 /// absolutely required for these to succeed.
589 bool isSizeOf(Type *&AllocTy) const;
590 bool isAlignOf(Type *&AllocTy) const;
591 bool isOffsetOf(Type *&STy, Constant *&FieldNo) const;
592 /// @}
593
getType()594 Type *getType() const { return getValPtr()->getType(); }
595
596 /// Methods for support type inquiry through isa, cast, and dyn_cast:
classof(const SCEV * S)597 static bool classof(const SCEV *S) { return S->getSCEVType() == scUnknown; }
598 };
599
600 /// This class defines a simple visitor class that may be used for
601 /// various SCEV analysis purposes.
602 template <typename SC, typename RetVal = void> struct SCEVVisitor {
visitSCEVVisitor603 RetVal visit(const SCEV *S) {
604 switch (S->getSCEVType()) {
605 case scConstant:
606 return ((SC *)this)->visitConstant((const SCEVConstant *)S);
607 case scPtrToInt:
608 return ((SC *)this)->visitPtrToIntExpr((const SCEVPtrToIntExpr *)S);
609 case scTruncate:
610 return ((SC *)this)->visitTruncateExpr((const SCEVTruncateExpr *)S);
611 case scZeroExtend:
612 return ((SC *)this)->visitZeroExtendExpr((const SCEVZeroExtendExpr *)S);
613 case scSignExtend:
614 return ((SC *)this)->visitSignExtendExpr((const SCEVSignExtendExpr *)S);
615 case scAddExpr:
616 return ((SC *)this)->visitAddExpr((const SCEVAddExpr *)S);
617 case scMulExpr:
618 return ((SC *)this)->visitMulExpr((const SCEVMulExpr *)S);
619 case scUDivExpr:
620 return ((SC *)this)->visitUDivExpr((const SCEVUDivExpr *)S);
621 case scAddRecExpr:
622 return ((SC *)this)->visitAddRecExpr((const SCEVAddRecExpr *)S);
623 case scSMaxExpr:
624 return ((SC *)this)->visitSMaxExpr((const SCEVSMaxExpr *)S);
625 case scUMaxExpr:
626 return ((SC *)this)->visitUMaxExpr((const SCEVUMaxExpr *)S);
627 case scSMinExpr:
628 return ((SC *)this)->visitSMinExpr((const SCEVSMinExpr *)S);
629 case scUMinExpr:
630 return ((SC *)this)->visitUMinExpr((const SCEVUMinExpr *)S);
631 case scSequentialUMinExpr:
632 return ((SC *)this)
633 ->visitSequentialUMinExpr((const SCEVSequentialUMinExpr *)S);
634 case scUnknown:
635 return ((SC *)this)->visitUnknown((const SCEVUnknown *)S);
636 case scCouldNotCompute:
637 return ((SC *)this)->visitCouldNotCompute((const SCEVCouldNotCompute *)S);
638 }
639 llvm_unreachable("Unknown SCEV kind!");
640 }
641
visitCouldNotComputeSCEVVisitor642 RetVal visitCouldNotCompute(const SCEVCouldNotCompute *S) {
643 llvm_unreachable("Invalid use of SCEVCouldNotCompute!");
644 }
645 };
646
647 /// Visit all nodes in the expression tree using worklist traversal.
648 ///
649 /// Visitor implements:
650 /// // return true to follow this node.
651 /// bool follow(const SCEV *S);
652 /// // return true to terminate the search.
653 /// bool isDone();
654 template <typename SV> class SCEVTraversal {
655 SV &Visitor;
656 SmallVector<const SCEV *, 8> Worklist;
657 SmallPtrSet<const SCEV *, 8> Visited;
658
push(const SCEV * S)659 void push(const SCEV *S) {
660 if (Visited.insert(S).second && Visitor.follow(S))
661 Worklist.push_back(S);
662 }
663
664 public:
SCEVTraversal(SV & V)665 SCEVTraversal(SV &V) : Visitor(V) {}
666
visitAll(const SCEV * Root)667 void visitAll(const SCEV *Root) {
668 push(Root);
669 while (!Worklist.empty() && !Visitor.isDone()) {
670 const SCEV *S = Worklist.pop_back_val();
671
672 switch (S->getSCEVType()) {
673 case scConstant:
674 case scUnknown:
675 continue;
676 case scPtrToInt:
677 case scTruncate:
678 case scZeroExtend:
679 case scSignExtend:
680 case scAddExpr:
681 case scMulExpr:
682 case scUDivExpr:
683 case scSMaxExpr:
684 case scUMaxExpr:
685 case scSMinExpr:
686 case scUMinExpr:
687 case scSequentialUMinExpr:
688 case scAddRecExpr:
689 for (const auto *Op : S->operands()) {
690 push(Op);
691 if (Visitor.isDone())
692 break;
693 }
694 continue;
695 case scCouldNotCompute:
696 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
697 }
698 llvm_unreachable("Unknown SCEV kind!");
699 }
700 }
701 };
702
703 /// Use SCEVTraversal to visit all nodes in the given expression tree.
visitAll(const SCEV * Root,SV & Visitor)704 template <typename SV> void visitAll(const SCEV *Root, SV &Visitor) {
705 SCEVTraversal<SV> T(Visitor);
706 T.visitAll(Root);
707 }
708
709 /// Return true if any node in \p Root satisfies the predicate \p Pred.
710 template <typename PredTy>
SCEVExprContains(const SCEV * Root,PredTy Pred)711 bool SCEVExprContains(const SCEV *Root, PredTy Pred) {
712 struct FindClosure {
713 bool Found = false;
714 PredTy Pred;
715
716 FindClosure(PredTy Pred) : Pred(Pred) {}
717
718 bool follow(const SCEV *S) {
719 if (!Pred(S))
720 return true;
721
722 Found = true;
723 return false;
724 }
725
726 bool isDone() const { return Found; }
727 };
728
729 FindClosure FC(Pred);
730 visitAll(Root, FC);
731 return FC.Found;
732 }
733
734 /// This visitor recursively visits a SCEV expression and re-writes it.
735 /// The result from each visit is cached, so it will return the same
736 /// SCEV for the same input.
737 template <typename SC>
738 class SCEVRewriteVisitor : public SCEVVisitor<SC, const SCEV *> {
739 protected:
740 ScalarEvolution &SE;
741 // Memoize the result of each visit so that we only compute once for
742 // the same input SCEV. This is to avoid redundant computations when
743 // a SCEV is referenced by multiple SCEVs. Without memoization, this
744 // visit algorithm would have exponential time complexity in the worst
745 // case, causing the compiler to hang on certain tests.
746 DenseMap<const SCEV *, const SCEV *> RewriteResults;
747
748 public:
/// Create a rewriter that builds any new expressions through \p SE.
SCEVRewriteVisitor(ScalarEvolution &SE)
    : SE(SE) {}
750
/// Rewrite \p S, returning a cached result if \p S was rewritten before.
const SCEV *visit(const SCEV *S) {
  if (auto Cached = RewriteResults.find(S); Cached != RewriteResults.end())
    return Cached->second;
  // The map is only touched again after the recursive rewrite has finished,
  // so no iterator is held across it.
  const SCEV *Rewritten = SCEVVisitor<SC, const SCEV *>::visit(S);
  auto [Slot, Inserted] = RewriteResults.try_emplace(S, Rewritten);
  assert(Inserted && "Should insert a new entry");
  (void)Inserted;
  return Slot->second;
}
760
visitConstant(const SCEVConstant * Constant)761 const SCEV *visitConstant(const SCEVConstant *Constant) { return Constant; }
762
visitPtrToIntExpr(const SCEVPtrToIntExpr * Expr)763 const SCEV *visitPtrToIntExpr(const SCEVPtrToIntExpr *Expr) {
764 const SCEV *Operand = ((SC *)this)->visit(Expr->getOperand());
765 return Operand == Expr->getOperand()
766 ? Expr
767 : SE.getPtrToIntExpr(Operand, Expr->getType());
768 }
769
visitTruncateExpr(const SCEVTruncateExpr * Expr)770 const SCEV *visitTruncateExpr(const SCEVTruncateExpr *Expr) {
771 const SCEV *Operand = ((SC *)this)->visit(Expr->getOperand());
772 return Operand == Expr->getOperand()
773 ? Expr
774 : SE.getTruncateExpr(Operand, Expr->getType());
775 }
776
visitZeroExtendExpr(const SCEVZeroExtendExpr * Expr)777 const SCEV *visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) {
778 const SCEV *Operand = ((SC *)this)->visit(Expr->getOperand());
779 return Operand == Expr->getOperand()
780 ? Expr
781 : SE.getZeroExtendExpr(Operand, Expr->getType());
782 }
783
visitSignExtendExpr(const SCEVSignExtendExpr * Expr)784 const SCEV *visitSignExtendExpr(const SCEVSignExtendExpr *Expr) {
785 const SCEV *Operand = ((SC *)this)->visit(Expr->getOperand());
786 return Operand == Expr->getOperand()
787 ? Expr
788 : SE.getSignExtendExpr(Operand, Expr->getType());
789 }
790
visitAddExpr(const SCEVAddExpr * Expr)791 const SCEV *visitAddExpr(const SCEVAddExpr *Expr) {
792 SmallVector<const SCEV *, 2> Operands;
793 bool Changed = false;
794 for (const auto *Op : Expr->operands()) {
795 Operands.push_back(((SC *)this)->visit(Op));
796 Changed |= Op != Operands.back();
797 }
798 return !Changed ? Expr : SE.getAddExpr(Operands);
799 }
800
visitMulExpr(const SCEVMulExpr * Expr)801 const SCEV *visitMulExpr(const SCEVMulExpr *Expr) {
802 SmallVector<const SCEV *, 2> Operands;
803 bool Changed = false;
804 for (const auto *Op : Expr->operands()) {
805 Operands.push_back(((SC *)this)->visit(Op));
806 Changed |= Op != Operands.back();
807 }
808 return !Changed ? Expr : SE.getMulExpr(Operands);
809 }
810
visitUDivExpr(const SCEVUDivExpr * Expr)811 const SCEV *visitUDivExpr(const SCEVUDivExpr *Expr) {
812 auto *LHS = ((SC *)this)->visit(Expr->getLHS());
813 auto *RHS = ((SC *)this)->visit(Expr->getRHS());
814 bool Changed = LHS != Expr->getLHS() || RHS != Expr->getRHS();
815 return !Changed ? Expr : SE.getUDivExpr(LHS, RHS);
816 }
817
visitAddRecExpr(const SCEVAddRecExpr * Expr)818 const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) {
819 SmallVector<const SCEV *, 2> Operands;
820 bool Changed = false;
821 for (const auto *Op : Expr->operands()) {
822 Operands.push_back(((SC *)this)->visit(Op));
823 Changed |= Op != Operands.back();
824 }
825 return !Changed ? Expr
826 : SE.getAddRecExpr(Operands, Expr->getLoop(),
827 Expr->getNoWrapFlags());
828 }
829
visitSMaxExpr(const SCEVSMaxExpr * Expr)830 const SCEV *visitSMaxExpr(const SCEVSMaxExpr *Expr) {
831 SmallVector<const SCEV *, 2> Operands;
832 bool Changed = false;
833 for (const auto *Op : Expr->operands()) {
834 Operands.push_back(((SC *)this)->visit(Op));
835 Changed |= Op != Operands.back();
836 }
837 return !Changed ? Expr : SE.getSMaxExpr(Operands);
838 }
839
visitUMaxExpr(const SCEVUMaxExpr * Expr)840 const SCEV *visitUMaxExpr(const SCEVUMaxExpr *Expr) {
841 SmallVector<const SCEV *, 2> Operands;
842 bool Changed = false;
843 for (const auto *Op : Expr->operands()) {
844 Operands.push_back(((SC *)this)->visit(Op));
845 Changed |= Op != Operands.back();
846 }
847 return !Changed ? Expr : SE.getUMaxExpr(Operands);
848 }
849
visitSMinExpr(const SCEVSMinExpr * Expr)850 const SCEV *visitSMinExpr(const SCEVSMinExpr *Expr) {
851 SmallVector<const SCEV *, 2> Operands;
852 bool Changed = false;
853 for (const auto *Op : Expr->operands()) {
854 Operands.push_back(((SC *)this)->visit(Op));
855 Changed |= Op != Operands.back();
856 }
857 return !Changed ? Expr : SE.getSMinExpr(Operands);
858 }
859
visitUMinExpr(const SCEVUMinExpr * Expr)860 const SCEV *visitUMinExpr(const SCEVUMinExpr *Expr) {
861 SmallVector<const SCEV *, 2> Operands;
862 bool Changed = false;
863 for (const auto *Op : Expr->operands()) {
864 Operands.push_back(((SC *)this)->visit(Op));
865 Changed |= Op != Operands.back();
866 }
867 return !Changed ? Expr : SE.getUMinExpr(Operands);
868 }
869
visitSequentialUMinExpr(const SCEVSequentialUMinExpr * Expr)870 const SCEV *visitSequentialUMinExpr(const SCEVSequentialUMinExpr *Expr) {
871 SmallVector<const SCEV *, 2> Operands;
872 bool Changed = false;
873 for (const auto *Op : Expr->operands()) {
874 Operands.push_back(((SC *)this)->visit(Op));
875 Changed |= Op != Operands.back();
876 }
877 return !Changed ? Expr : SE.getUMinExpr(Operands, /*Sequential=*/true);
878 }
879
visitUnknown(const SCEVUnknown * Expr)880 const SCEV *visitUnknown(const SCEVUnknown *Expr) { return Expr; }
881
visitCouldNotCompute(const SCEVCouldNotCompute * Expr)882 const SCEV *visitCouldNotCompute(const SCEVCouldNotCompute *Expr) {
883 return Expr;
884 }
885 };
886
887 using ValueToValueMap = DenseMap<const Value *, Value *>;
888 using ValueToSCEVMapTy = DenseMap<const Value *, const SCEV *>;
889
890 /// The SCEVParameterRewriter takes a scalar evolution expression and updates
891 /// the SCEVUnknown components following the Map (Value -> SCEV).
892 class SCEVParameterRewriter : public SCEVRewriteVisitor<SCEVParameterRewriter> {
893 public:
rewrite(const SCEV * Scev,ScalarEvolution & SE,ValueToSCEVMapTy & Map)894 static const SCEV *rewrite(const SCEV *Scev, ScalarEvolution &SE,
895 ValueToSCEVMapTy &Map) {
896 SCEVParameterRewriter Rewriter(SE, Map);
897 return Rewriter.visit(Scev);
898 }
899
SCEVParameterRewriter(ScalarEvolution & SE,ValueToSCEVMapTy & M)900 SCEVParameterRewriter(ScalarEvolution &SE, ValueToSCEVMapTy &M)
901 : SCEVRewriteVisitor(SE), Map(M) {}
902
visitUnknown(const SCEVUnknown * Expr)903 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
904 auto I = Map.find(Expr->getValue());
905 if (I == Map.end())
906 return Expr;
907 return I->second;
908 }
909
910 private:
911 ValueToSCEVMapTy ⤅
912 };
913
914 using LoopToScevMapT = DenseMap<const Loop *, const SCEV *>;
915
916 /// The SCEVLoopAddRecRewriter takes a scalar evolution expression and applies
917 /// the Map (Loop -> SCEV) to all AddRecExprs.
918 class SCEVLoopAddRecRewriter
919 : public SCEVRewriteVisitor<SCEVLoopAddRecRewriter> {
920 public:
SCEVLoopAddRecRewriter(ScalarEvolution & SE,LoopToScevMapT & M)921 SCEVLoopAddRecRewriter(ScalarEvolution &SE, LoopToScevMapT &M)
922 : SCEVRewriteVisitor(SE), Map(M) {}
923
rewrite(const SCEV * Scev,LoopToScevMapT & Map,ScalarEvolution & SE)924 static const SCEV *rewrite(const SCEV *Scev, LoopToScevMapT &Map,
925 ScalarEvolution &SE) {
926 SCEVLoopAddRecRewriter Rewriter(SE, Map);
927 return Rewriter.visit(Scev);
928 }
929
visitAddRecExpr(const SCEVAddRecExpr * Expr)930 const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) {
931 SmallVector<const SCEV *, 2> Operands;
932 for (const SCEV *Op : Expr->operands())
933 Operands.push_back(visit(Op));
934
935 const Loop *L = Expr->getLoop();
936 if (0 == Map.count(L))
937 return SE.getAddRecExpr(Operands, L, Expr->getNoWrapFlags());
938
939 return SCEVAddRecExpr::evaluateAtIteration(Operands, Map[L], SE);
940 }
941
942 private:
943 LoopToScevMapT ⤅
944 };
945
946 } // end namespace llvm
947
948 #endif // LLVM_ANALYSIS_SCALAREVOLUTIONEXPRESSIONS_H
949