1 //===- llvm/Analysis/ScalarEvolutionExpressions.h - SCEV Exprs --*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the classes used to represent and build scalar expressions.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #ifndef LLVM_ANALYSIS_SCALAREVOLUTIONEXPRESSIONS_H
14 #define LLVM_ANALYSIS_SCALAREVOLUTIONEXPRESSIONS_H
15 
16 #include "llvm/ADT/DenseMap.h"
17 #include "llvm/ADT/SmallPtrSet.h"
18 #include "llvm/ADT/SmallVector.h"
19 #include "llvm/ADT/iterator_range.h"
20 #include "llvm/Analysis/ScalarEvolution.h"
21 #include "llvm/IR/Constants.h"
22 #include "llvm/IR/ValueHandle.h"
23 #include "llvm/Support/Casting.h"
24 #include "llvm/Support/ErrorHandling.h"
25 #include <cassert>
26 #include <cstddef>
27 
28 namespace llvm {
29 
30 class APInt;
31 class Constant;
32 class ConstantInt;
33 class ConstantRange;
34 class Loop;
35 class Type;
36 class Value;
37 
/// Discriminator identifying the concrete kind of a SCEV node. The classes in
/// this file compare a node's getSCEVType() against these values in their
/// classof() implementations.
enum SCEVTypes : unsigned short {
  // These should be ordered in terms of increasing complexity to make the
  // folders simpler.
  scConstant,           // SCEVConstant
  scTruncate,           // SCEVTruncateExpr
  scZeroExtend,         // SCEVZeroExtendExpr
  scSignExtend,         // SCEVSignExtendExpr
  scAddExpr,            // SCEVAddExpr
  scMulExpr,            // SCEVMulExpr
  scUDivExpr,           // SCEVUDivExpr
  scAddRecExpr,         // SCEVAddRecExpr
  scUMaxExpr,           // SCEVUMaxExpr
  scSMaxExpr,           // SCEVSMaxExpr
  scUMinExpr,           // SCEVUMinExpr
  scSMinExpr,           // SCEVSMinExpr
  scSequentialUMinExpr, // SCEVSequentialUMinExpr
  scPtrToInt,           // SCEVPtrToIntExpr
  scUnknown,            // SCEVUnknown
  scCouldNotCompute     // SCEVCouldNotCompute
};
58 
59 /// This class represents a constant integer value.
60 class SCEVConstant : public SCEV {
61   friend class ScalarEvolution;
62 
63   ConstantInt *V;
64 
65   SCEVConstant(const FoldingSetNodeIDRef ID, ConstantInt *v)
66       : SCEV(ID, scConstant, 1), V(v) {}
67 
68 public:
69   ConstantInt *getValue() const { return V; }
70   const APInt &getAPInt() const { return getValue()->getValue(); }
71 
72   Type *getType() const { return V->getType(); }
73 
74   /// Methods for support type inquiry through isa, cast, and dyn_cast:
75   static bool classof(const SCEV *S) { return S->getSCEVType() == scConstant; }
76 };
77 
78 inline unsigned short computeExpressionSize(ArrayRef<const SCEV *> Args) {
79   APInt Size(16, 1);
80   for (auto *Arg : Args)
81     Size = Size.uadd_sat(APInt(16, Arg->getExpressionSize()));
82   return (unsigned short)Size.getZExtValue();
83 }
84 
85 /// This is the base class for unary cast operator classes.
86 class SCEVCastExpr : public SCEV {
87 protected:
88   std::array<const SCEV *, 1> Operands;
89   Type *Ty;
90 
91   SCEVCastExpr(const FoldingSetNodeIDRef ID, SCEVTypes SCEVTy, const SCEV *op,
92                Type *ty);
93 
94 public:
95   const SCEV *getOperand() const { return Operands[0]; }
96   const SCEV *getOperand(unsigned i) const {
97     assert(i == 0 && "Operand index out of range!");
98     return Operands[0];
99   }
100   using op_iterator = std::array<const SCEV *, 1>::const_iterator;
101   using op_range = iterator_range<op_iterator>;
102 
103   op_range operands() const {
104     return make_range(Operands.begin(), Operands.end());
105   }
106   size_t getNumOperands() const { return 1; }
107   Type *getType() const { return Ty; }
108 
109   /// Methods for support type inquiry through isa, cast, and dyn_cast:
110   static bool classof(const SCEV *S) {
111     return S->getSCEVType() == scPtrToInt || S->getSCEVType() == scTruncate ||
112            S->getSCEVType() == scZeroExtend || S->getSCEVType() == scSignExtend;
113   }
114 };
115 
116 /// This class represents a cast from a pointer to a pointer-sized integer
117 /// value.
118 class SCEVPtrToIntExpr : public SCEVCastExpr {
119   friend class ScalarEvolution;
120 
121   SCEVPtrToIntExpr(const FoldingSetNodeIDRef ID, const SCEV *Op, Type *ITy);
122 
123 public:
124   /// Methods for support type inquiry through isa, cast, and dyn_cast:
125   static bool classof(const SCEV *S) { return S->getSCEVType() == scPtrToInt; }
126 };
127 
128 /// This is the base class for unary integral cast operator classes.
129 class SCEVIntegralCastExpr : public SCEVCastExpr {
130 protected:
131   SCEVIntegralCastExpr(const FoldingSetNodeIDRef ID, SCEVTypes SCEVTy,
132                        const SCEV *op, Type *ty);
133 
134 public:
135   /// Methods for support type inquiry through isa, cast, and dyn_cast:
136   static bool classof(const SCEV *S) {
137     return S->getSCEVType() == scTruncate || S->getSCEVType() == scZeroExtend ||
138            S->getSCEVType() == scSignExtend;
139   }
140 };
141 
142 /// This class represents a truncation of an integer value to a
143 /// smaller integer value.
144 class SCEVTruncateExpr : public SCEVIntegralCastExpr {
145   friend class ScalarEvolution;
146 
147   SCEVTruncateExpr(const FoldingSetNodeIDRef ID, const SCEV *op, Type *ty);
148 
149 public:
150   /// Methods for support type inquiry through isa, cast, and dyn_cast:
151   static bool classof(const SCEV *S) { return S->getSCEVType() == scTruncate; }
152 };
153 
154 /// This class represents a zero extension of a small integer value
155 /// to a larger integer value.
156 class SCEVZeroExtendExpr : public SCEVIntegralCastExpr {
157   friend class ScalarEvolution;
158 
159   SCEVZeroExtendExpr(const FoldingSetNodeIDRef ID, const SCEV *op, Type *ty);
160 
161 public:
162   /// Methods for support type inquiry through isa, cast, and dyn_cast:
163   static bool classof(const SCEV *S) {
164     return S->getSCEVType() == scZeroExtend;
165   }
166 };
167 
168 /// This class represents a sign extension of a small integer value
169 /// to a larger integer value.
170 class SCEVSignExtendExpr : public SCEVIntegralCastExpr {
171   friend class ScalarEvolution;
172 
173   SCEVSignExtendExpr(const FoldingSetNodeIDRef ID, const SCEV *op, Type *ty);
174 
175 public:
176   /// Methods for support type inquiry through isa, cast, and dyn_cast:
177   static bool classof(const SCEV *S) {
178     return S->getSCEVType() == scSignExtend;
179   }
180 };
181 
182 /// This node is a base class providing common functionality for
183 /// n'ary operators.
184 class SCEVNAryExpr : public SCEV {
185 protected:
186   // Since SCEVs are immutable, ScalarEvolution allocates operand
187   // arrays with its SCEVAllocator, so this class just needs a simple
188   // pointer rather than a more elaborate vector-like data structure.
189   // This also avoids the need for a non-trivial destructor.
190   const SCEV *const *Operands;
191   size_t NumOperands;
192 
193   SCEVNAryExpr(const FoldingSetNodeIDRef ID, enum SCEVTypes T,
194                const SCEV *const *O, size_t N)
195       : SCEV(ID, T, computeExpressionSize(makeArrayRef(O, N))), Operands(O),
196         NumOperands(N) {}
197 
198 public:
199   size_t getNumOperands() const { return NumOperands; }
200 
201   const SCEV *getOperand(unsigned i) const {
202     assert(i < NumOperands && "Operand index out of range!");
203     return Operands[i];
204   }
205 
206   using op_iterator = const SCEV *const *;
207   using op_range = iterator_range<op_iterator>;
208 
209   op_iterator op_begin() const { return Operands; }
210   op_iterator op_end() const { return Operands + NumOperands; }
211   op_range operands() const { return make_range(op_begin(), op_end()); }
212 
213   NoWrapFlags getNoWrapFlags(NoWrapFlags Mask = NoWrapMask) const {
214     return (NoWrapFlags)(SubclassData & Mask);
215   }
216 
217   bool hasNoUnsignedWrap() const {
218     return getNoWrapFlags(FlagNUW) != FlagAnyWrap;
219   }
220 
221   bool hasNoSignedWrap() const {
222     return getNoWrapFlags(FlagNSW) != FlagAnyWrap;
223   }
224 
225   bool hasNoSelfWrap() const { return getNoWrapFlags(FlagNW) != FlagAnyWrap; }
226 
227   /// Methods for support type inquiry through isa, cast, and dyn_cast:
228   static bool classof(const SCEV *S) {
229     return S->getSCEVType() == scAddExpr || S->getSCEVType() == scMulExpr ||
230            S->getSCEVType() == scSMaxExpr || S->getSCEVType() == scUMaxExpr ||
231            S->getSCEVType() == scSMinExpr || S->getSCEVType() == scUMinExpr ||
232            S->getSCEVType() == scSequentialUMinExpr ||
233            S->getSCEVType() == scAddRecExpr;
234   }
235 };
236 
237 /// This node is the base class for n'ary commutative operators.
238 class SCEVCommutativeExpr : public SCEVNAryExpr {
239 protected:
240   SCEVCommutativeExpr(const FoldingSetNodeIDRef ID, enum SCEVTypes T,
241                       const SCEV *const *O, size_t N)
242       : SCEVNAryExpr(ID, T, O, N) {}
243 
244 public:
245   /// Methods for support type inquiry through isa, cast, and dyn_cast:
246   static bool classof(const SCEV *S) {
247     return S->getSCEVType() == scAddExpr || S->getSCEVType() == scMulExpr ||
248            S->getSCEVType() == scSMaxExpr || S->getSCEVType() == scUMaxExpr ||
249            S->getSCEVType() == scSMinExpr || S->getSCEVType() == scUMinExpr;
250   }
251 
252   /// Set flags for a non-recurrence without clearing previously set flags.
253   void setNoWrapFlags(NoWrapFlags Flags) { SubclassData |= Flags; }
254 };
255 
256 /// This node represents an addition of some number of SCEVs.
257 class SCEVAddExpr : public SCEVCommutativeExpr {
258   friend class ScalarEvolution;
259 
260   Type *Ty;
261 
262   SCEVAddExpr(const FoldingSetNodeIDRef ID, const SCEV *const *O, size_t N)
263       : SCEVCommutativeExpr(ID, scAddExpr, O, N) {
264     auto *FirstPointerTypedOp = find_if(operands(), [](const SCEV *Op) {
265       return Op->getType()->isPointerTy();
266     });
267     if (FirstPointerTypedOp != operands().end())
268       Ty = (*FirstPointerTypedOp)->getType();
269     else
270       Ty = getOperand(0)->getType();
271   }
272 
273 public:
274   Type *getType() const { return Ty; }
275 
276   /// Methods for support type inquiry through isa, cast, and dyn_cast:
277   static bool classof(const SCEV *S) { return S->getSCEVType() == scAddExpr; }
278 };
279 
280 /// This node represents multiplication of some number of SCEVs.
281 class SCEVMulExpr : public SCEVCommutativeExpr {
282   friend class ScalarEvolution;
283 
284   SCEVMulExpr(const FoldingSetNodeIDRef ID, const SCEV *const *O, size_t N)
285       : SCEVCommutativeExpr(ID, scMulExpr, O, N) {}
286 
287 public:
288   Type *getType() const { return getOperand(0)->getType(); }
289 
290   /// Methods for support type inquiry through isa, cast, and dyn_cast:
291   static bool classof(const SCEV *S) { return S->getSCEVType() == scMulExpr; }
292 };
293 
294 /// This class represents a binary unsigned division operation.
295 class SCEVUDivExpr : public SCEV {
296   friend class ScalarEvolution;
297 
298   std::array<const SCEV *, 2> Operands;
299 
300   SCEVUDivExpr(const FoldingSetNodeIDRef ID, const SCEV *lhs, const SCEV *rhs)
301       : SCEV(ID, scUDivExpr, computeExpressionSize({lhs, rhs})) {
302     Operands[0] = lhs;
303     Operands[1] = rhs;
304   }
305 
306 public:
307   const SCEV *getLHS() const { return Operands[0]; }
308   const SCEV *getRHS() const { return Operands[1]; }
309   size_t getNumOperands() const { return 2; }
310   const SCEV *getOperand(unsigned i) const {
311     assert((i == 0 || i == 1) && "Operand index out of range!");
312     return i == 0 ? getLHS() : getRHS();
313   }
314 
315   using op_iterator = std::array<const SCEV *, 2>::const_iterator;
316   using op_range = iterator_range<op_iterator>;
317   op_range operands() const {
318     return make_range(Operands.begin(), Operands.end());
319   }
320 
321   Type *getType() const {
322     // In most cases the types of LHS and RHS will be the same, but in some
323     // crazy cases one or the other may be a pointer. ScalarEvolution doesn't
324     // depend on the type for correctness, but handling types carefully can
325     // avoid extra casts in the SCEVExpander. The LHS is more likely to be
326     // a pointer type than the RHS, so use the RHS' type here.
327     return getRHS()->getType();
328   }
329 
330   /// Methods for support type inquiry through isa, cast, and dyn_cast:
331   static bool classof(const SCEV *S) { return S->getSCEVType() == scUDivExpr; }
332 };
333 
334 /// This node represents a polynomial recurrence on the trip count
335 /// of the specified loop.  This is the primary focus of the
336 /// ScalarEvolution framework; all the other SCEV subclasses are
337 /// mostly just supporting infrastructure to allow SCEVAddRecExpr
338 /// expressions to be created and analyzed.
339 ///
340 /// All operands of an AddRec are required to be loop invariant.
341 ///
342 class SCEVAddRecExpr : public SCEVNAryExpr {
343   friend class ScalarEvolution;
344 
345   const Loop *L;
346 
347   SCEVAddRecExpr(const FoldingSetNodeIDRef ID, const SCEV *const *O, size_t N,
348                  const Loop *l)
349       : SCEVNAryExpr(ID, scAddRecExpr, O, N), L(l) {}
350 
351 public:
352   Type *getType() const { return getStart()->getType(); }
353   const SCEV *getStart() const { return Operands[0]; }
354   const Loop *getLoop() const { return L; }
355 
356   /// Constructs and returns the recurrence indicating how much this
357   /// expression steps by.  If this is a polynomial of degree N, it
358   /// returns a chrec of degree N-1.  We cannot determine whether
359   /// the step recurrence has self-wraparound.
360   const SCEV *getStepRecurrence(ScalarEvolution &SE) const {
361     if (isAffine())
362       return getOperand(1);
363     return SE.getAddRecExpr(
364         SmallVector<const SCEV *, 3>(op_begin() + 1, op_end()), getLoop(),
365         FlagAnyWrap);
366   }
367 
368   /// Return true if this represents an expression A + B*x where A
369   /// and B are loop invariant values.
370   bool isAffine() const {
371     // We know that the start value is invariant.  This expression is thus
372     // affine iff the step is also invariant.
373     return getNumOperands() == 2;
374   }
375 
376   /// Return true if this represents an expression A + B*x + C*x^2
377   /// where A, B and C are loop invariant values.  This corresponds
378   /// to an addrec of the form {L,+,M,+,N}
379   bool isQuadratic() const { return getNumOperands() == 3; }
380 
381   /// Set flags for a recurrence without clearing any previously set flags.
382   /// For AddRec, either NUW or NSW implies NW. Keep track of this fact here
383   /// to make it easier to propagate flags.
384   void setNoWrapFlags(NoWrapFlags Flags) {
385     if (Flags & (FlagNUW | FlagNSW))
386       Flags = ScalarEvolution::setFlags(Flags, FlagNW);
387     SubclassData |= Flags;
388   }
389 
390   /// Return the value of this chain of recurrences at the specified
391   /// iteration number.
392   const SCEV *evaluateAtIteration(const SCEV *It, ScalarEvolution &SE) const;
393 
394   /// Return the value of this chain of recurrences at the specified iteration
395   /// number. Takes an explicit list of operands to represent an AddRec.
396   static const SCEV *evaluateAtIteration(ArrayRef<const SCEV *> Operands,
397                                          const SCEV *It, ScalarEvolution &SE);
398 
399   /// Return the number of iterations of this loop that produce
400   /// values in the specified constant range.  Another way of
401   /// looking at this is that it returns the first iteration number
402   /// where the value is not in the condition, thus computing the
403   /// exit count.  If the iteration count can't be computed, an
404   /// instance of SCEVCouldNotCompute is returned.
405   const SCEV *getNumIterationsInRange(const ConstantRange &Range,
406                                       ScalarEvolution &SE) const;
407 
408   /// Return an expression representing the value of this expression
409   /// one iteration of the loop ahead.
410   const SCEVAddRecExpr *getPostIncExpr(ScalarEvolution &SE) const;
411 
412   /// Methods for support type inquiry through isa, cast, and dyn_cast:
413   static bool classof(const SCEV *S) {
414     return S->getSCEVType() == scAddRecExpr;
415   }
416 };
417 
418 /// This node is the base class min/max selections.
419 class SCEVMinMaxExpr : public SCEVCommutativeExpr {
420   friend class ScalarEvolution;
421 
422   static bool isMinMaxType(enum SCEVTypes T) {
423     return T == scSMaxExpr || T == scUMaxExpr || T == scSMinExpr ||
424            T == scUMinExpr;
425   }
426 
427 protected:
428   /// Note: Constructing subclasses via this constructor is allowed
429   SCEVMinMaxExpr(const FoldingSetNodeIDRef ID, enum SCEVTypes T,
430                  const SCEV *const *O, size_t N)
431       : SCEVCommutativeExpr(ID, T, O, N) {
432     assert(isMinMaxType(T));
433     // Min and max never overflow
434     setNoWrapFlags((NoWrapFlags)(FlagNUW | FlagNSW));
435   }
436 
437 public:
438   Type *getType() const { return getOperand(0)->getType(); }
439 
440   static bool classof(const SCEV *S) { return isMinMaxType(S->getSCEVType()); }
441 
442   static enum SCEVTypes negate(enum SCEVTypes T) {
443     switch (T) {
444     case scSMaxExpr:
445       return scSMinExpr;
446     case scSMinExpr:
447       return scSMaxExpr;
448     case scUMaxExpr:
449       return scUMinExpr;
450     case scUMinExpr:
451       return scUMaxExpr;
452     default:
453       llvm_unreachable("Not a min or max SCEV type!");
454     }
455   }
456 };
457 
458 /// This class represents a signed maximum selection.
459 class SCEVSMaxExpr : public SCEVMinMaxExpr {
460   friend class ScalarEvolution;
461 
462   SCEVSMaxExpr(const FoldingSetNodeIDRef ID, const SCEV *const *O, size_t N)
463       : SCEVMinMaxExpr(ID, scSMaxExpr, O, N) {}
464 
465 public:
466   /// Methods for support type inquiry through isa, cast, and dyn_cast:
467   static bool classof(const SCEV *S) { return S->getSCEVType() == scSMaxExpr; }
468 };
469 
470 /// This class represents an unsigned maximum selection.
471 class SCEVUMaxExpr : public SCEVMinMaxExpr {
472   friend class ScalarEvolution;
473 
474   SCEVUMaxExpr(const FoldingSetNodeIDRef ID, const SCEV *const *O, size_t N)
475       : SCEVMinMaxExpr(ID, scUMaxExpr, O, N) {}
476 
477 public:
478   /// Methods for support type inquiry through isa, cast, and dyn_cast:
479   static bool classof(const SCEV *S) { return S->getSCEVType() == scUMaxExpr; }
480 };
481 
482 /// This class represents a signed minimum selection.
483 class SCEVSMinExpr : public SCEVMinMaxExpr {
484   friend class ScalarEvolution;
485 
486   SCEVSMinExpr(const FoldingSetNodeIDRef ID, const SCEV *const *O, size_t N)
487       : SCEVMinMaxExpr(ID, scSMinExpr, O, N) {}
488 
489 public:
490   /// Methods for support type inquiry through isa, cast, and dyn_cast:
491   static bool classof(const SCEV *S) { return S->getSCEVType() == scSMinExpr; }
492 };
493 
494 /// This class represents an unsigned minimum selection.
495 class SCEVUMinExpr : public SCEVMinMaxExpr {
496   friend class ScalarEvolution;
497 
498   SCEVUMinExpr(const FoldingSetNodeIDRef ID, const SCEV *const *O, size_t N)
499       : SCEVMinMaxExpr(ID, scUMinExpr, O, N) {}
500 
501 public:
502   /// Methods for support type inquiry through isa, cast, and dyn_cast:
503   static bool classof(const SCEV *S) { return S->getSCEVType() == scUMinExpr; }
504 };
505 
506 /// This node is the base class for sequential/in-order min/max selections.
507 /// Note that their fundamental difference from SCEVMinMaxExpr's is that they
508 /// are early-returning upon reaching saturation point.
509 /// I.e. given `0 umin_seq poison`, the result will be `0`,
510 /// while the result of `0 umin poison` is `poison`.
511 class SCEVSequentialMinMaxExpr : public SCEVNAryExpr {
512   friend class ScalarEvolution;
513 
514   static bool isSequentialMinMaxType(enum SCEVTypes T) {
515     return T == scSequentialUMinExpr;
516   }
517 
518   /// Set flags for a non-recurrence without clearing previously set flags.
519   void setNoWrapFlags(NoWrapFlags Flags) { SubclassData |= Flags; }
520 
521 protected:
522   /// Note: Constructing subclasses via this constructor is allowed
523   SCEVSequentialMinMaxExpr(const FoldingSetNodeIDRef ID, enum SCEVTypes T,
524                            const SCEV *const *O, size_t N)
525       : SCEVNAryExpr(ID, T, O, N) {
526     assert(isSequentialMinMaxType(T));
527     // Min and max never overflow
528     setNoWrapFlags((NoWrapFlags)(FlagNUW | FlagNSW));
529   }
530 
531 public:
532   Type *getType() const { return getOperand(0)->getType(); }
533 
534   static SCEVTypes getEquivalentNonSequentialSCEVType(SCEVTypes Ty) {
535     assert(isSequentialMinMaxType(Ty));
536     switch (Ty) {
537     case scSequentialUMinExpr:
538       return scUMinExpr;
539     default:
540       llvm_unreachable("Not a sequential min/max type.");
541     }
542   }
543 
544   SCEVTypes getEquivalentNonSequentialSCEVType() const {
545     return getEquivalentNonSequentialSCEVType(getSCEVType());
546   }
547 
548   static bool classof(const SCEV *S) {
549     return isSequentialMinMaxType(S->getSCEVType());
550   }
551 };
552 
553 /// This class represents a sequential/in-order unsigned minimum selection.
554 class SCEVSequentialUMinExpr : public SCEVSequentialMinMaxExpr {
555   friend class ScalarEvolution;
556 
557   SCEVSequentialUMinExpr(const FoldingSetNodeIDRef ID, const SCEV *const *O,
558                          size_t N)
559       : SCEVSequentialMinMaxExpr(ID, scSequentialUMinExpr, O, N) {}
560 
561 public:
562   /// Methods for support type inquiry through isa, cast, and dyn_cast:
563   static bool classof(const SCEV *S) {
564     return S->getSCEVType() == scSequentialUMinExpr;
565   }
566 };
567 
568 /// This means that we are dealing with an entirely unknown SCEV
569 /// value, and only represent it as its LLVM Value.  This is the
570 /// "bottom" value for the analysis.
571 class SCEVUnknown final : public SCEV, private CallbackVH {
572   friend class ScalarEvolution;
573 
574   /// The parent ScalarEvolution value. This is used to update the
575   /// parent's maps when the value associated with a SCEVUnknown is
576   /// deleted or RAUW'd.
577   ScalarEvolution *SE;
578 
579   /// The next pointer in the linked list of all SCEVUnknown
580   /// instances owned by a ScalarEvolution.
581   SCEVUnknown *Next;
582 
583   SCEVUnknown(const FoldingSetNodeIDRef ID, Value *V, ScalarEvolution *se,
584               SCEVUnknown *next)
585       : SCEV(ID, scUnknown, 1), CallbackVH(V), SE(se), Next(next) {}
586 
587   // Implement CallbackVH.
588   void deleted() override;
589   void allUsesReplacedWith(Value *New) override;
590 
591 public:
592   Value *getValue() const { return getValPtr(); }
593 
594   /// @{
595   /// Test whether this is a special constant representing a type
596   /// size, alignment, or field offset in a target-independent
597   /// manner, and hasn't happened to have been folded with other
598   /// operations into something unrecognizable. This is mainly only
599   /// useful for pretty-printing and other situations where it isn't
600   /// absolutely required for these to succeed.
601   bool isSizeOf(Type *&AllocTy) const;
602   bool isAlignOf(Type *&AllocTy) const;
603   bool isOffsetOf(Type *&STy, Constant *&FieldNo) const;
604   /// @}
605 
606   Type *getType() const { return getValPtr()->getType(); }
607 
608   /// Methods for support type inquiry through isa, cast, and dyn_cast:
609   static bool classof(const SCEV *S) { return S->getSCEVType() == scUnknown; }
610 };
611 
612 /// This class defines a simple visitor class that may be used for
613 /// various SCEV analysis purposes.
614 template <typename SC, typename RetVal = void> struct SCEVVisitor {
615   RetVal visit(const SCEV *S) {
616     switch (S->getSCEVType()) {
617     case scConstant:
618       return ((SC *)this)->visitConstant((const SCEVConstant *)S);
619     case scPtrToInt:
620       return ((SC *)this)->visitPtrToIntExpr((const SCEVPtrToIntExpr *)S);
621     case scTruncate:
622       return ((SC *)this)->visitTruncateExpr((const SCEVTruncateExpr *)S);
623     case scZeroExtend:
624       return ((SC *)this)->visitZeroExtendExpr((const SCEVZeroExtendExpr *)S);
625     case scSignExtend:
626       return ((SC *)this)->visitSignExtendExpr((const SCEVSignExtendExpr *)S);
627     case scAddExpr:
628       return ((SC *)this)->visitAddExpr((const SCEVAddExpr *)S);
629     case scMulExpr:
630       return ((SC *)this)->visitMulExpr((const SCEVMulExpr *)S);
631     case scUDivExpr:
632       return ((SC *)this)->visitUDivExpr((const SCEVUDivExpr *)S);
633     case scAddRecExpr:
634       return ((SC *)this)->visitAddRecExpr((const SCEVAddRecExpr *)S);
635     case scSMaxExpr:
636       return ((SC *)this)->visitSMaxExpr((const SCEVSMaxExpr *)S);
637     case scUMaxExpr:
638       return ((SC *)this)->visitUMaxExpr((const SCEVUMaxExpr *)S);
639     case scSMinExpr:
640       return ((SC *)this)->visitSMinExpr((const SCEVSMinExpr *)S);
641     case scUMinExpr:
642       return ((SC *)this)->visitUMinExpr((const SCEVUMinExpr *)S);
643     case scSequentialUMinExpr:
644       return ((SC *)this)
645           ->visitSequentialUMinExpr((const SCEVSequentialUMinExpr *)S);
646     case scUnknown:
647       return ((SC *)this)->visitUnknown((const SCEVUnknown *)S);
648     case scCouldNotCompute:
649       return ((SC *)this)->visitCouldNotCompute((const SCEVCouldNotCompute *)S);
650     }
651     llvm_unreachable("Unknown SCEV kind!");
652   }
653 
654   RetVal visitCouldNotCompute(const SCEVCouldNotCompute *S) {
655     llvm_unreachable("Invalid use of SCEVCouldNotCompute!");
656   }
657 };
658 
659 /// Visit all nodes in the expression tree using worklist traversal.
660 ///
661 /// Visitor implements:
662 ///   // return true to follow this node.
663 ///   bool follow(const SCEV *S);
664 ///   // return true to terminate the search.
665 ///   bool isDone();
666 template <typename SV> class SCEVTraversal {
667   SV &Visitor;
668   SmallVector<const SCEV *, 8> Worklist;
669   SmallPtrSet<const SCEV *, 8> Visited;
670 
671   void push(const SCEV *S) {
672     if (Visited.insert(S).second && Visitor.follow(S))
673       Worklist.push_back(S);
674   }
675 
676 public:
677   SCEVTraversal(SV &V) : Visitor(V) {}
678 
679   void visitAll(const SCEV *Root) {
680     push(Root);
681     while (!Worklist.empty() && !Visitor.isDone()) {
682       const SCEV *S = Worklist.pop_back_val();
683 
684       switch (S->getSCEVType()) {
685       case scConstant:
686       case scUnknown:
687         continue;
688       case scPtrToInt:
689       case scTruncate:
690       case scZeroExtend:
691       case scSignExtend:
692         push(cast<SCEVCastExpr>(S)->getOperand());
693         continue;
694       case scAddExpr:
695       case scMulExpr:
696       case scSMaxExpr:
697       case scUMaxExpr:
698       case scSMinExpr:
699       case scUMinExpr:
700       case scSequentialUMinExpr:
701       case scAddRecExpr:
702         for (const auto *Op : cast<SCEVNAryExpr>(S)->operands()) {
703           push(Op);
704           if (Visitor.isDone())
705             break;
706         }
707         continue;
708       case scUDivExpr: {
709         const SCEVUDivExpr *UDiv = cast<SCEVUDivExpr>(S);
710         push(UDiv->getLHS());
711         push(UDiv->getRHS());
712         continue;
713       }
714       case scCouldNotCompute:
715         llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
716       }
717       llvm_unreachable("Unknown SCEV kind!");
718     }
719   }
720 };
721 
722 /// Use SCEVTraversal to visit all nodes in the given expression tree.
723 template <typename SV> void visitAll(const SCEV *Root, SV &Visitor) {
724   SCEVTraversal<SV> T(Visitor);
725   T.visitAll(Root);
726 }
727 
728 /// Return true if any node in \p Root satisfies the predicate \p Pred.
729 template <typename PredTy>
730 bool SCEVExprContains(const SCEV *Root, PredTy Pred) {
731   struct FindClosure {
732     bool Found = false;
733     PredTy Pred;
734 
735     FindClosure(PredTy Pred) : Pred(Pred) {}
736 
737     bool follow(const SCEV *S) {
738       if (!Pred(S))
739         return true;
740 
741       Found = true;
742       return false;
743     }
744 
745     bool isDone() const { return Found; }
746   };
747 
748   FindClosure FC(Pred);
749   visitAll(Root, FC);
750   return FC.Found;
751 }
752 
753 /// This visitor recursively visits a SCEV expression and re-writes it.
754 /// The result from each visit is cached, so it will return the same
755 /// SCEV for the same input.
756 template <typename SC>
757 class SCEVRewriteVisitor : public SCEVVisitor<SC, const SCEV *> {
758 protected:
759   ScalarEvolution &SE;
760   // Memoize the result of each visit so that we only compute once for
761   // the same input SCEV. This is to avoid redundant computations when
762   // a SCEV is referenced by multiple SCEVs. Without memoization, this
763   // visit algorithm would have exponential time complexity in the worst
764   // case, causing the compiler to hang on certain tests.
765   DenseMap<const SCEV *, const SCEV *> RewriteResults;
766 
767 public:
768   SCEVRewriteVisitor(ScalarEvolution &SE) : SE(SE) {}
769 
770   const SCEV *visit(const SCEV *S) {
771     auto It = RewriteResults.find(S);
772     if (It != RewriteResults.end())
773       return It->second;
774     auto *Visited = SCEVVisitor<SC, const SCEV *>::visit(S);
775     auto Result = RewriteResults.try_emplace(S, Visited);
776     assert(Result.second && "Should insert a new entry");
777     return Result.first->second;
778   }
779 
780   const SCEV *visitConstant(const SCEVConstant *Constant) { return Constant; }
781 
782   const SCEV *visitPtrToIntExpr(const SCEVPtrToIntExpr *Expr) {
783     const SCEV *Operand = ((SC *)this)->visit(Expr->getOperand());
784     return Operand == Expr->getOperand()
785                ? Expr
786                : SE.getPtrToIntExpr(Operand, Expr->getType());
787   }
788 
789   const SCEV *visitTruncateExpr(const SCEVTruncateExpr *Expr) {
790     const SCEV *Operand = ((SC *)this)->visit(Expr->getOperand());
791     return Operand == Expr->getOperand()
792                ? Expr
793                : SE.getTruncateExpr(Operand, Expr->getType());
794   }
795 
796   const SCEV *visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) {
797     const SCEV *Operand = ((SC *)this)->visit(Expr->getOperand());
798     return Operand == Expr->getOperand()
799                ? Expr
800                : SE.getZeroExtendExpr(Operand, Expr->getType());
801   }
802 
803   const SCEV *visitSignExtendExpr(const SCEVSignExtendExpr *Expr) {
804     const SCEV *Operand = ((SC *)this)->visit(Expr->getOperand());
805     return Operand == Expr->getOperand()
806                ? Expr
807                : SE.getSignExtendExpr(Operand, Expr->getType());
808   }
809 
810   const SCEV *visitAddExpr(const SCEVAddExpr *Expr) {
811     SmallVector<const SCEV *, 2> Operands;
812     bool Changed = false;
813     for (auto *Op : Expr->operands()) {
814       Operands.push_back(((SC *)this)->visit(Op));
815       Changed |= Op != Operands.back();
816     }
817     return !Changed ? Expr : SE.getAddExpr(Operands);
818   }
819 
820   const SCEV *visitMulExpr(const SCEVMulExpr *Expr) {
821     SmallVector<const SCEV *, 2> Operands;
822     bool Changed = false;
823     for (auto *Op : Expr->operands()) {
824       Operands.push_back(((SC *)this)->visit(Op));
825       Changed |= Op != Operands.back();
826     }
827     return !Changed ? Expr : SE.getMulExpr(Operands);
828   }
829 
830   const SCEV *visitUDivExpr(const SCEVUDivExpr *Expr) {
831     auto *LHS = ((SC *)this)->visit(Expr->getLHS());
832     auto *RHS = ((SC *)this)->visit(Expr->getRHS());
833     bool Changed = LHS != Expr->getLHS() || RHS != Expr->getRHS();
834     return !Changed ? Expr : SE.getUDivExpr(LHS, RHS);
835   }
836 
837   const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) {
838     SmallVector<const SCEV *, 2> Operands;
839     bool Changed = false;
840     for (auto *Op : Expr->operands()) {
841       Operands.push_back(((SC *)this)->visit(Op));
842       Changed |= Op != Operands.back();
843     }
844     return !Changed ? Expr
845                     : SE.getAddRecExpr(Operands, Expr->getLoop(),
846                                        Expr->getNoWrapFlags());
847   }
848 
849   const SCEV *visitSMaxExpr(const SCEVSMaxExpr *Expr) {
850     SmallVector<const SCEV *, 2> Operands;
851     bool Changed = false;
852     for (auto *Op : Expr->operands()) {
853       Operands.push_back(((SC *)this)->visit(Op));
854       Changed |= Op != Operands.back();
855     }
856     return !Changed ? Expr : SE.getSMaxExpr(Operands);
857   }
858 
859   const SCEV *visitUMaxExpr(const SCEVUMaxExpr *Expr) {
860     SmallVector<const SCEV *, 2> Operands;
861     bool Changed = false;
862     for (auto *Op : Expr->operands()) {
863       Operands.push_back(((SC *)this)->visit(Op));
864       Changed |= Op != Operands.back();
865     }
866     return !Changed ? Expr : SE.getUMaxExpr(Operands);
867   }
868 
869   const SCEV *visitSMinExpr(const SCEVSMinExpr *Expr) {
870     SmallVector<const SCEV *, 2> Operands;
871     bool Changed = false;
872     for (auto *Op : Expr->operands()) {
873       Operands.push_back(((SC *)this)->visit(Op));
874       Changed |= Op != Operands.back();
875     }
876     return !Changed ? Expr : SE.getSMinExpr(Operands);
877   }
878 
879   const SCEV *visitUMinExpr(const SCEVUMinExpr *Expr) {
880     SmallVector<const SCEV *, 2> Operands;
881     bool Changed = false;
882     for (auto *Op : Expr->operands()) {
883       Operands.push_back(((SC *)this)->visit(Op));
884       Changed |= Op != Operands.back();
885     }
886     return !Changed ? Expr : SE.getUMinExpr(Operands);
887   }
888 
889   const SCEV *visitSequentialUMinExpr(const SCEVSequentialUMinExpr *Expr) {
890     SmallVector<const SCEV *, 2> Operands;
891     bool Changed = false;
892     for (auto *Op : Expr->operands()) {
893       Operands.push_back(((SC *)this)->visit(Op));
894       Changed |= Op != Operands.back();
895     }
896     return !Changed ? Expr : SE.getUMinExpr(Operands, /*Sequential=*/true);
897   }
898 
899   const SCEV *visitUnknown(const SCEVUnknown *Expr) { return Expr; }
900 
901   const SCEV *visitCouldNotCompute(const SCEVCouldNotCompute *Expr) {
902     return Expr;
903   }
904 };
905 
/// Map keyed by a const Value pointer whose mapped type is a (non-const)
/// Value pointer.
using ValueToValueMap = DenseMap<const Value *, Value *>;
/// Map keyed by a const Value pointer whose mapped type is a SCEV. This is the
/// substitution table type accepted by SCEVParameterRewriter.
using ValueToSCEVMapTy = DenseMap<const Value *, const SCEV *>;
908 
909 /// The SCEVParameterRewriter takes a scalar evolution expression and updates
910 /// the SCEVUnknown components following the Map (Value -> SCEV).
911 class SCEVParameterRewriter : public SCEVRewriteVisitor<SCEVParameterRewriter> {
912 public:
913   static const SCEV *rewrite(const SCEV *Scev, ScalarEvolution &SE,
914                              ValueToSCEVMapTy &Map) {
915     SCEVParameterRewriter Rewriter(SE, Map);
916     return Rewriter.visit(Scev);
917   }
918 
919   SCEVParameterRewriter(ScalarEvolution &SE, ValueToSCEVMapTy &M)
920       : SCEVRewriteVisitor(SE), Map(M) {}
921 
922   const SCEV *visitUnknown(const SCEVUnknown *Expr) {
923     auto I = Map.find(Expr->getValue());
924     if (I == Map.end())
925       return Expr;
926     return I->second;
927   }
928 
929 private:
930   ValueToSCEVMapTy &Map;
931 };
932 
/// Map keyed by a const Loop pointer whose mapped type is a SCEV. This is the
/// table type accepted by SCEVLoopAddRecRewriter.
using LoopToScevMapT = DenseMap<const Loop *, const SCEV *>;
934 
935 /// The SCEVLoopAddRecRewriter takes a scalar evolution expression and applies
936 /// the Map (Loop -> SCEV) to all AddRecExprs.
937 class SCEVLoopAddRecRewriter
938     : public SCEVRewriteVisitor<SCEVLoopAddRecRewriter> {
939 public:
940   SCEVLoopAddRecRewriter(ScalarEvolution &SE, LoopToScevMapT &M)
941       : SCEVRewriteVisitor(SE), Map(M) {}
942 
943   static const SCEV *rewrite(const SCEV *Scev, LoopToScevMapT &Map,
944                              ScalarEvolution &SE) {
945     SCEVLoopAddRecRewriter Rewriter(SE, Map);
946     return Rewriter.visit(Scev);
947   }
948 
949   const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) {
950     SmallVector<const SCEV *, 2> Operands;
951     for (const SCEV *Op : Expr->operands())
952       Operands.push_back(visit(Op));
953 
954     const Loop *L = Expr->getLoop();
955     if (0 == Map.count(L))
956       return SE.getAddRecExpr(Operands, L, Expr->getNoWrapFlags());
957 
958     return SCEVAddRecExpr::evaluateAtIteration(Operands, Map[L], SE);
959   }
960 
961 private:
962   LoopToScevMapT &Map;
963 };
964 
965 } // end namespace llvm
966 
967 #endif // LLVM_ANALYSIS_SCALAREVOLUTIONEXPRESSIONS_H
968