1 //===- llvm/Analysis/ScalarEvolutionExpressions.h - SCEV Exprs --*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the classes used to represent and build scalar expressions.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #ifndef LLVM_ANALYSIS_SCALAREVOLUTIONEXPRESSIONS_H
14 #define LLVM_ANALYSIS_SCALAREVOLUTIONEXPRESSIONS_H
15 
16 #include "llvm/ADT/DenseMap.h"
17 #include "llvm/ADT/FoldingSet.h"
18 #include "llvm/ADT/SmallPtrSet.h"
19 #include "llvm/ADT/SmallVector.h"
20 #include "llvm/ADT/iterator_range.h"
21 #include "llvm/Analysis/ScalarEvolution.h"
22 #include "llvm/IR/Constants.h"
23 #include "llvm/IR/Value.h"
24 #include "llvm/IR/ValueHandle.h"
25 #include "llvm/Support/Casting.h"
26 #include "llvm/Support/ErrorHandling.h"
27 #include <cassert>
28 #include <cstddef>
29 
30 namespace llvm {
31 
32 class APInt;
33 class Constant;
34 class ConstantRange;
35 class Loop;
36 class Type;
37 
/// Discriminator identifying the concrete kind of a SCEV node. It is what
/// SCEV::getSCEVType() returns, and it is switched on by the classof()
/// implementations, SCEVVisitor::visit() and SCEVTraversal::visitAll() below.
enum SCEVTypes : unsigned short {
  // These should be ordered in terms of increasing complexity to make the
  // folders simpler.
  // NOTE(review): because the folders depend on this relative order, any
  // insertion or reordering needs an audit of those folders first.
  scConstant,
  scTruncate,
  scZeroExtend,
  scSignExtend,
  scAddExpr,
  scMulExpr,
  scUDivExpr,
  scAddRecExpr,
  scUMaxExpr,
  scSMaxExpr,
  scUMinExpr,
  scSMinExpr,
  scSequentialUMinExpr,
  scPtrToInt,
  scUnknown,
  // Kind of SCEVCouldNotCompute; the visitors and traversal below treat
  // reaching it as a fatal error.
  scCouldNotCompute
};
58 
59 /// This class represents a constant integer value.
60 class SCEVConstant : public SCEV {
61   friend class ScalarEvolution;
62 
63   ConstantInt *V;
64 
65   SCEVConstant(const FoldingSetNodeIDRef ID, ConstantInt *v)
66       : SCEV(ID, scConstant, 1), V(v) {}
67 
68 public:
69   ConstantInt *getValue() const { return V; }
70   const APInt &getAPInt() const { return getValue()->getValue(); }
71 
72   Type *getType() const { return V->getType(); }
73 
74   /// Methods for support type inquiry through isa, cast, and dyn_cast:
75   static bool classof(const SCEV *S) { return S->getSCEVType() == scConstant; }
76 };
77 
78 inline unsigned short computeExpressionSize(ArrayRef<const SCEV *> Args) {
79   APInt Size(16, 1);
80   for (auto *Arg : Args)
81     Size = Size.uadd_sat(APInt(16, Arg->getExpressionSize()));
82   return (unsigned short)Size.getZExtValue();
83 }
84 
85 /// This is the base class for unary cast operator classes.
86 class SCEVCastExpr : public SCEV {
87 protected:
88   std::array<const SCEV *, 1> Operands;
89   Type *Ty;
90 
91   SCEVCastExpr(const FoldingSetNodeIDRef ID, SCEVTypes SCEVTy, const SCEV *op,
92                Type *ty);
93 
94 public:
95   const SCEV *getOperand() const { return Operands[0]; }
96   const SCEV *getOperand(unsigned i) const {
97     assert(i == 0 && "Operand index out of range!");
98     return Operands[0];
99   }
100   using op_iterator = std::array<const SCEV *, 1>::const_iterator;
101   using op_range = iterator_range<op_iterator>;
102 
103   op_range operands() const {
104     return make_range(Operands.begin(), Operands.end());
105   }
106   size_t getNumOperands() const { return 1; }
107   Type *getType() const { return Ty; }
108 
109   /// Methods for support type inquiry through isa, cast, and dyn_cast:
110   static bool classof(const SCEV *S) {
111     return S->getSCEVType() == scPtrToInt || S->getSCEVType() == scTruncate ||
112            S->getSCEVType() == scZeroExtend || S->getSCEVType() == scSignExtend;
113   }
114 };
115 
116 /// This class represents a cast from a pointer to a pointer-sized integer
117 /// value.
118 class SCEVPtrToIntExpr : public SCEVCastExpr {
119   friend class ScalarEvolution;
120 
121   SCEVPtrToIntExpr(const FoldingSetNodeIDRef ID, const SCEV *Op, Type *ITy);
122 
123 public:
124   /// Methods for support type inquiry through isa, cast, and dyn_cast:
125   static bool classof(const SCEV *S) { return S->getSCEVType() == scPtrToInt; }
126 };
127 
128 /// This is the base class for unary integral cast operator classes.
129 class SCEVIntegralCastExpr : public SCEVCastExpr {
130 protected:
131   SCEVIntegralCastExpr(const FoldingSetNodeIDRef ID, SCEVTypes SCEVTy,
132                        const SCEV *op, Type *ty);
133 
134 public:
135   /// Methods for support type inquiry through isa, cast, and dyn_cast:
136   static bool classof(const SCEV *S) {
137     return S->getSCEVType() == scTruncate || S->getSCEVType() == scZeroExtend ||
138            S->getSCEVType() == scSignExtend;
139   }
140 };
141 
142 /// This class represents a truncation of an integer value to a
143 /// smaller integer value.
144 class SCEVTruncateExpr : public SCEVIntegralCastExpr {
145   friend class ScalarEvolution;
146 
147   SCEVTruncateExpr(const FoldingSetNodeIDRef ID, const SCEV *op, Type *ty);
148 
149 public:
150   /// Methods for support type inquiry through isa, cast, and dyn_cast:
151   static bool classof(const SCEV *S) { return S->getSCEVType() == scTruncate; }
152 };
153 
154 /// This class represents a zero extension of a small integer value
155 /// to a larger integer value.
156 class SCEVZeroExtendExpr : public SCEVIntegralCastExpr {
157   friend class ScalarEvolution;
158 
159   SCEVZeroExtendExpr(const FoldingSetNodeIDRef ID, const SCEV *op, Type *ty);
160 
161 public:
162   /// Methods for support type inquiry through isa, cast, and dyn_cast:
163   static bool classof(const SCEV *S) {
164     return S->getSCEVType() == scZeroExtend;
165   }
166 };
167 
168 /// This class represents a sign extension of a small integer value
169 /// to a larger integer value.
170 class SCEVSignExtendExpr : public SCEVIntegralCastExpr {
171   friend class ScalarEvolution;
172 
173   SCEVSignExtendExpr(const FoldingSetNodeIDRef ID, const SCEV *op, Type *ty);
174 
175 public:
176   /// Methods for support type inquiry through isa, cast, and dyn_cast:
177   static bool classof(const SCEV *S) {
178     return S->getSCEVType() == scSignExtend;
179   }
180 };
181 
182 /// This node is a base class providing common functionality for
183 /// n'ary operators.
184 class SCEVNAryExpr : public SCEV {
185 protected:
186   // Since SCEVs are immutable, ScalarEvolution allocates operand
187   // arrays with its SCEVAllocator, so this class just needs a simple
188   // pointer rather than a more elaborate vector-like data structure.
189   // This also avoids the need for a non-trivial destructor.
190   const SCEV *const *Operands;
191   size_t NumOperands;
192 
193   SCEVNAryExpr(const FoldingSetNodeIDRef ID, enum SCEVTypes T,
194                const SCEV *const *O, size_t N)
195       : SCEV(ID, T, computeExpressionSize(makeArrayRef(O, N))), Operands(O),
196         NumOperands(N) {}
197 
198 public:
199   size_t getNumOperands() const { return NumOperands; }
200 
201   const SCEV *getOperand(unsigned i) const {
202     assert(i < NumOperands && "Operand index out of range!");
203     return Operands[i];
204   }
205 
206   using op_iterator = const SCEV *const *;
207   using op_range = iterator_range<op_iterator>;
208 
209   op_iterator op_begin() const { return Operands; }
210   op_iterator op_end() const { return Operands + NumOperands; }
211   op_range operands() const { return make_range(op_begin(), op_end()); }
212 
213   NoWrapFlags getNoWrapFlags(NoWrapFlags Mask = NoWrapMask) const {
214     return (NoWrapFlags)(SubclassData & Mask);
215   }
216 
217   bool hasNoUnsignedWrap() const {
218     return getNoWrapFlags(FlagNUW) != FlagAnyWrap;
219   }
220 
221   bool hasNoSignedWrap() const {
222     return getNoWrapFlags(FlagNSW) != FlagAnyWrap;
223   }
224 
225   bool hasNoSelfWrap() const { return getNoWrapFlags(FlagNW) != FlagAnyWrap; }
226 
227   /// Methods for support type inquiry through isa, cast, and dyn_cast:
228   static bool classof(const SCEV *S) {
229     return S->getSCEVType() == scAddExpr || S->getSCEVType() == scMulExpr ||
230            S->getSCEVType() == scSMaxExpr || S->getSCEVType() == scUMaxExpr ||
231            S->getSCEVType() == scSMinExpr || S->getSCEVType() == scUMinExpr ||
232            S->getSCEVType() == scSequentialUMinExpr ||
233            S->getSCEVType() == scAddRecExpr;
234   }
235 };
236 
237 /// This node is the base class for n'ary commutative operators.
238 class SCEVCommutativeExpr : public SCEVNAryExpr {
239 protected:
240   SCEVCommutativeExpr(const FoldingSetNodeIDRef ID, enum SCEVTypes T,
241                       const SCEV *const *O, size_t N)
242       : SCEVNAryExpr(ID, T, O, N) {}
243 
244 public:
245   /// Methods for support type inquiry through isa, cast, and dyn_cast:
246   static bool classof(const SCEV *S) {
247     return S->getSCEVType() == scAddExpr || S->getSCEVType() == scMulExpr ||
248            S->getSCEVType() == scSMaxExpr || S->getSCEVType() == scUMaxExpr ||
249            S->getSCEVType() == scSMinExpr || S->getSCEVType() == scUMinExpr;
250   }
251 
252   /// Set flags for a non-recurrence without clearing previously set flags.
253   void setNoWrapFlags(NoWrapFlags Flags) { SubclassData |= Flags; }
254 };
255 
256 /// This node represents an addition of some number of SCEVs.
257 class SCEVAddExpr : public SCEVCommutativeExpr {
258   friend class ScalarEvolution;
259 
260   Type *Ty;
261 
262   SCEVAddExpr(const FoldingSetNodeIDRef ID, const SCEV *const *O, size_t N)
263       : SCEVCommutativeExpr(ID, scAddExpr, O, N) {
264     auto *FirstPointerTypedOp = find_if(operands(), [](const SCEV *Op) {
265       return Op->getType()->isPointerTy();
266     });
267     if (FirstPointerTypedOp != operands().end())
268       Ty = (*FirstPointerTypedOp)->getType();
269     else
270       Ty = getOperand(0)->getType();
271   }
272 
273 public:
274   Type *getType() const { return Ty; }
275 
276   /// Methods for support type inquiry through isa, cast, and dyn_cast:
277   static bool classof(const SCEV *S) { return S->getSCEVType() == scAddExpr; }
278 };
279 
280 /// This node represents multiplication of some number of SCEVs.
281 class SCEVMulExpr : public SCEVCommutativeExpr {
282   friend class ScalarEvolution;
283 
284   SCEVMulExpr(const FoldingSetNodeIDRef ID, const SCEV *const *O, size_t N)
285       : SCEVCommutativeExpr(ID, scMulExpr, O, N) {}
286 
287 public:
288   Type *getType() const { return getOperand(0)->getType(); }
289 
290   /// Methods for support type inquiry through isa, cast, and dyn_cast:
291   static bool classof(const SCEV *S) { return S->getSCEVType() == scMulExpr; }
292 };
293 
294 /// This class represents a binary unsigned division operation.
295 class SCEVUDivExpr : public SCEV {
296   friend class ScalarEvolution;
297 
298   std::array<const SCEV *, 2> Operands;
299 
300   SCEVUDivExpr(const FoldingSetNodeIDRef ID, const SCEV *lhs, const SCEV *rhs)
301       : SCEV(ID, scUDivExpr, computeExpressionSize({lhs, rhs})) {
302     Operands[0] = lhs;
303     Operands[1] = rhs;
304   }
305 
306 public:
307   const SCEV *getLHS() const { return Operands[0]; }
308   const SCEV *getRHS() const { return Operands[1]; }
309   size_t getNumOperands() const { return 2; }
310   const SCEV *getOperand(unsigned i) const {
311     assert((i == 0 || i == 1) && "Operand index out of range!");
312     return i == 0 ? getLHS() : getRHS();
313   }
314 
315   using op_iterator = std::array<const SCEV *, 2>::const_iterator;
316   using op_range = iterator_range<op_iterator>;
317   op_range operands() const {
318     return make_range(Operands.begin(), Operands.end());
319   }
320 
321   Type *getType() const {
322     // In most cases the types of LHS and RHS will be the same, but in some
323     // crazy cases one or the other may be a pointer. ScalarEvolution doesn't
324     // depend on the type for correctness, but handling types carefully can
325     // avoid extra casts in the SCEVExpander. The LHS is more likely to be
326     // a pointer type than the RHS, so use the RHS' type here.
327     return getRHS()->getType();
328   }
329 
330   /// Methods for support type inquiry through isa, cast, and dyn_cast:
331   static bool classof(const SCEV *S) { return S->getSCEVType() == scUDivExpr; }
332 };
333 
334 /// This node represents a polynomial recurrence on the trip count
335 /// of the specified loop.  This is the primary focus of the
336 /// ScalarEvolution framework; all the other SCEV subclasses are
337 /// mostly just supporting infrastructure to allow SCEVAddRecExpr
338 /// expressions to be created and analyzed.
339 ///
340 /// All operands of an AddRec are required to be loop invariant.
341 ///
342 class SCEVAddRecExpr : public SCEVNAryExpr {
343   friend class ScalarEvolution;
344 
345   const Loop *L;
346 
347   SCEVAddRecExpr(const FoldingSetNodeIDRef ID, const SCEV *const *O, size_t N,
348                  const Loop *l)
349       : SCEVNAryExpr(ID, scAddRecExpr, O, N), L(l) {}
350 
351 public:
352   Type *getType() const { return getStart()->getType(); }
353   const SCEV *getStart() const { return Operands[0]; }
354   const Loop *getLoop() const { return L; }
355 
356   /// Constructs and returns the recurrence indicating how much this
357   /// expression steps by.  If this is a polynomial of degree N, it
358   /// returns a chrec of degree N-1.  We cannot determine whether
359   /// the step recurrence has self-wraparound.
360   const SCEV *getStepRecurrence(ScalarEvolution &SE) const {
361     if (isAffine())
362       return getOperand(1);
363     return SE.getAddRecExpr(
364         SmallVector<const SCEV *, 3>(op_begin() + 1, op_end()), getLoop(),
365         FlagAnyWrap);
366   }
367 
368   /// Return true if this represents an expression A + B*x where A
369   /// and B are loop invariant values.
370   bool isAffine() const {
371     // We know that the start value is invariant.  This expression is thus
372     // affine iff the step is also invariant.
373     return getNumOperands() == 2;
374   }
375 
376   /// Return true if this represents an expression A + B*x + C*x^2
377   /// where A, B and C are loop invariant values.  This corresponds
378   /// to an addrec of the form {L,+,M,+,N}
379   bool isQuadratic() const { return getNumOperands() == 3; }
380 
381   /// Set flags for a recurrence without clearing any previously set flags.
382   /// For AddRec, either NUW or NSW implies NW. Keep track of this fact here
383   /// to make it easier to propagate flags.
384   void setNoWrapFlags(NoWrapFlags Flags) {
385     if (Flags & (FlagNUW | FlagNSW))
386       Flags = ScalarEvolution::setFlags(Flags, FlagNW);
387     SubclassData |= Flags;
388   }
389 
390   /// Return the value of this chain of recurrences at the specified
391   /// iteration number.
392   const SCEV *evaluateAtIteration(const SCEV *It, ScalarEvolution &SE) const;
393 
394   /// Return the value of this chain of recurrences at the specified iteration
395   /// number. Takes an explicit list of operands to represent an AddRec.
396   static const SCEV *evaluateAtIteration(ArrayRef<const SCEV *> Operands,
397                                          const SCEV *It, ScalarEvolution &SE);
398 
399   /// Return the number of iterations of this loop that produce
400   /// values in the specified constant range.  Another way of
401   /// looking at this is that it returns the first iteration number
402   /// where the value is not in the condition, thus computing the
403   /// exit count.  If the iteration count can't be computed, an
404   /// instance of SCEVCouldNotCompute is returned.
405   const SCEV *getNumIterationsInRange(const ConstantRange &Range,
406                                       ScalarEvolution &SE) const;
407 
408   /// Return an expression representing the value of this expression
409   /// one iteration of the loop ahead.
410   const SCEVAddRecExpr *getPostIncExpr(ScalarEvolution &SE) const;
411 
412   /// Methods for support type inquiry through isa, cast, and dyn_cast:
413   static bool classof(const SCEV *S) {
414     return S->getSCEVType() == scAddRecExpr;
415   }
416 };
417 
418 /// This node is the base class min/max selections.
419 class SCEVMinMaxExpr : public SCEVCommutativeExpr {
420   friend class ScalarEvolution;
421 
422   static bool isMinMaxType(enum SCEVTypes T) {
423     return T == scSMaxExpr || T == scUMaxExpr || T == scSMinExpr ||
424            T == scUMinExpr;
425   }
426 
427 protected:
428   /// Note: Constructing subclasses via this constructor is allowed
429   SCEVMinMaxExpr(const FoldingSetNodeIDRef ID, enum SCEVTypes T,
430                  const SCEV *const *O, size_t N)
431       : SCEVCommutativeExpr(ID, T, O, N) {
432     assert(isMinMaxType(T));
433     // Min and max never overflow
434     setNoWrapFlags((NoWrapFlags)(FlagNUW | FlagNSW));
435   }
436 
437 public:
438   Type *getType() const { return getOperand(0)->getType(); }
439 
440   static bool classof(const SCEV *S) { return isMinMaxType(S->getSCEVType()); }
441 
442   static enum SCEVTypes negate(enum SCEVTypes T) {
443     switch (T) {
444     case scSMaxExpr:
445       return scSMinExpr;
446     case scSMinExpr:
447       return scSMaxExpr;
448     case scUMaxExpr:
449       return scUMinExpr;
450     case scUMinExpr:
451       return scUMaxExpr;
452     default:
453       llvm_unreachable("Not a min or max SCEV type!");
454     }
455   }
456 };
457 
458 /// This class represents a signed maximum selection.
459 class SCEVSMaxExpr : public SCEVMinMaxExpr {
460   friend class ScalarEvolution;
461 
462   SCEVSMaxExpr(const FoldingSetNodeIDRef ID, const SCEV *const *O, size_t N)
463       : SCEVMinMaxExpr(ID, scSMaxExpr, O, N) {}
464 
465 public:
466   /// Methods for support type inquiry through isa, cast, and dyn_cast:
467   static bool classof(const SCEV *S) { return S->getSCEVType() == scSMaxExpr; }
468 };
469 
470 /// This class represents an unsigned maximum selection.
471 class SCEVUMaxExpr : public SCEVMinMaxExpr {
472   friend class ScalarEvolution;
473 
474   SCEVUMaxExpr(const FoldingSetNodeIDRef ID, const SCEV *const *O, size_t N)
475       : SCEVMinMaxExpr(ID, scUMaxExpr, O, N) {}
476 
477 public:
478   /// Methods for support type inquiry through isa, cast, and dyn_cast:
479   static bool classof(const SCEV *S) { return S->getSCEVType() == scUMaxExpr; }
480 };
481 
482 /// This class represents a signed minimum selection.
483 class SCEVSMinExpr : public SCEVMinMaxExpr {
484   friend class ScalarEvolution;
485 
486   SCEVSMinExpr(const FoldingSetNodeIDRef ID, const SCEV *const *O, size_t N)
487       : SCEVMinMaxExpr(ID, scSMinExpr, O, N) {}
488 
489 public:
490   /// Methods for support type inquiry through isa, cast, and dyn_cast:
491   static bool classof(const SCEV *S) { return S->getSCEVType() == scSMinExpr; }
492 };
493 
494 /// This class represents an unsigned minimum selection.
495 class SCEVUMinExpr : public SCEVMinMaxExpr {
496   friend class ScalarEvolution;
497 
498   SCEVUMinExpr(const FoldingSetNodeIDRef ID, const SCEV *const *O, size_t N)
499       : SCEVMinMaxExpr(ID, scUMinExpr, O, N) {}
500 
501 public:
502   /// Methods for support type inquiry through isa, cast, and dyn_cast:
503   static bool classof(const SCEV *S) { return S->getSCEVType() == scUMinExpr; }
504 };
505 
506 /// This node is the base class for sequential/in-order min/max selections.
507 /// Note that their fundamental difference from SCEVMinMaxExpr's is that they
508 /// are early-returning upon reaching saturation point.
509 /// I.e. given `0 umin_seq poison`, the result will be `0`,
510 /// while the result of `0 umin poison` is `poison`.
511 class SCEVSequentialMinMaxExpr : public SCEVNAryExpr {
512   friend class ScalarEvolution;
513 
514   static bool isSequentialMinMaxType(enum SCEVTypes T) {
515     return T == scSequentialUMinExpr;
516   }
517 
518   /// Set flags for a non-recurrence without clearing previously set flags.
519   void setNoWrapFlags(NoWrapFlags Flags) { SubclassData |= Flags; }
520 
521 protected:
522   /// Note: Constructing subclasses via this constructor is allowed
523   SCEVSequentialMinMaxExpr(const FoldingSetNodeIDRef ID, enum SCEVTypes T,
524                            const SCEV *const *O, size_t N)
525       : SCEVNAryExpr(ID, T, O, N) {
526     assert(isSequentialMinMaxType(T));
527     // Min and max never overflow
528     setNoWrapFlags((NoWrapFlags)(FlagNUW | FlagNSW));
529   }
530 
531 public:
532   Type *getType() const { return getOperand(0)->getType(); }
533 
534   static SCEVTypes getEquivalentNonSequentialSCEVType(SCEVTypes Ty) {
535     assert(isSequentialMinMaxType(Ty));
536     switch (Ty) {
537     case scSequentialUMinExpr:
538       return scUMinExpr;
539     default:
540       llvm_unreachable("Not a sequential min/max type.");
541     }
542   }
543 
544   SCEVTypes getEquivalentNonSequentialSCEVType() const {
545     return getEquivalentNonSequentialSCEVType(getSCEVType());
546   }
547 
548   static bool classof(const SCEV *S) {
549     return isSequentialMinMaxType(S->getSCEVType());
550   }
551 };
552 
553 /// This class represents a sequential/in-order unsigned minimum selection.
554 class SCEVSequentialUMinExpr : public SCEVSequentialMinMaxExpr {
555   friend class ScalarEvolution;
556 
557   SCEVSequentialUMinExpr(const FoldingSetNodeIDRef ID, const SCEV *const *O,
558                          size_t N)
559       : SCEVSequentialMinMaxExpr(ID, scSequentialUMinExpr, O, N) {}
560 
561 public:
562   /// Methods for support type inquiry through isa, cast, and dyn_cast:
563   static bool classof(const SCEV *S) {
564     return S->getSCEVType() == scSequentialUMinExpr;
565   }
566 };
567 
568 /// This means that we are dealing with an entirely unknown SCEV
569 /// value, and only represent it as its LLVM Value.  This is the
570 /// "bottom" value for the analysis.
571 class SCEVUnknown final : public SCEV, private CallbackVH {
572   friend class ScalarEvolution;
573 
574   /// The parent ScalarEvolution value. This is used to update the
575   /// parent's maps when the value associated with a SCEVUnknown is
576   /// deleted or RAUW'd.
577   ScalarEvolution *SE;
578 
579   /// The next pointer in the linked list of all SCEVUnknown
580   /// instances owned by a ScalarEvolution.
581   SCEVUnknown *Next;
582 
583   SCEVUnknown(const FoldingSetNodeIDRef ID, Value *V, ScalarEvolution *se,
584               SCEVUnknown *next)
585       : SCEV(ID, scUnknown, 1), CallbackVH(V), SE(se), Next(next) {}
586 
587   // Implement CallbackVH.
588   void deleted() override;
589   void allUsesReplacedWith(Value *New) override;
590 
591 public:
592   Value *getValue() const { return getValPtr(); }
593 
594   /// @{
595   /// Test whether this is a special constant representing a type
596   /// size, alignment, or field offset in a target-independent
597   /// manner, and hasn't happened to have been folded with other
598   /// operations into something unrecognizable. This is mainly only
599   /// useful for pretty-printing and other situations where it isn't
600   /// absolutely required for these to succeed.
601   bool isSizeOf(Type *&AllocTy) const;
602   bool isAlignOf(Type *&AllocTy) const;
603   bool isOffsetOf(Type *&STy, Constant *&FieldNo) const;
604   /// @}
605 
606   Type *getType() const { return getValPtr()->getType(); }
607 
608   /// Methods for support type inquiry through isa, cast, and dyn_cast:
609   static bool classof(const SCEV *S) { return S->getSCEVType() == scUnknown; }
610 };
611 
612 /// This class defines a simple visitor class that may be used for
613 /// various SCEV analysis purposes.
614 template <typename SC, typename RetVal = void> struct SCEVVisitor {
615   RetVal visit(const SCEV *S) {
616     switch (S->getSCEVType()) {
617     case scConstant:
618       return ((SC *)this)->visitConstant((const SCEVConstant *)S);
619     case scPtrToInt:
620       return ((SC *)this)->visitPtrToIntExpr((const SCEVPtrToIntExpr *)S);
621     case scTruncate:
622       return ((SC *)this)->visitTruncateExpr((const SCEVTruncateExpr *)S);
623     case scZeroExtend:
624       return ((SC *)this)->visitZeroExtendExpr((const SCEVZeroExtendExpr *)S);
625     case scSignExtend:
626       return ((SC *)this)->visitSignExtendExpr((const SCEVSignExtendExpr *)S);
627     case scAddExpr:
628       return ((SC *)this)->visitAddExpr((const SCEVAddExpr *)S);
629     case scMulExpr:
630       return ((SC *)this)->visitMulExpr((const SCEVMulExpr *)S);
631     case scUDivExpr:
632       return ((SC *)this)->visitUDivExpr((const SCEVUDivExpr *)S);
633     case scAddRecExpr:
634       return ((SC *)this)->visitAddRecExpr((const SCEVAddRecExpr *)S);
635     case scSMaxExpr:
636       return ((SC *)this)->visitSMaxExpr((const SCEVSMaxExpr *)S);
637     case scUMaxExpr:
638       return ((SC *)this)->visitUMaxExpr((const SCEVUMaxExpr *)S);
639     case scSMinExpr:
640       return ((SC *)this)->visitSMinExpr((const SCEVSMinExpr *)S);
641     case scUMinExpr:
642       return ((SC *)this)->visitUMinExpr((const SCEVUMinExpr *)S);
643     case scSequentialUMinExpr:
644       return ((SC *)this)
645           ->visitSequentialUMinExpr((const SCEVSequentialUMinExpr *)S);
646     case scUnknown:
647       return ((SC *)this)->visitUnknown((const SCEVUnknown *)S);
648     case scCouldNotCompute:
649       return ((SC *)this)->visitCouldNotCompute((const SCEVCouldNotCompute *)S);
650     }
651     llvm_unreachable("Unknown SCEV kind!");
652   }
653 
654   RetVal visitCouldNotCompute(const SCEVCouldNotCompute *S) {
655     llvm_unreachable("Invalid use of SCEVCouldNotCompute!");
656   }
657 };
658 
659 /// Visit all nodes in the expression tree using worklist traversal.
660 ///
661 /// Visitor implements:
662 ///   // return true to follow this node.
663 ///   bool follow(const SCEV *S);
664 ///   // return true to terminate the search.
665 ///   bool isDone();
666 template <typename SV> class SCEVTraversal {
667   SV &Visitor;
668   SmallVector<const SCEV *, 8> Worklist;
669   SmallPtrSet<const SCEV *, 8> Visited;
670 
671   void push(const SCEV *S) {
672     if (Visited.insert(S).second && Visitor.follow(S))
673       Worklist.push_back(S);
674   }
675 
676 public:
677   SCEVTraversal(SV &V) : Visitor(V) {}
678 
679   void visitAll(const SCEV *Root) {
680     push(Root);
681     while (!Worklist.empty() && !Visitor.isDone()) {
682       const SCEV *S = Worklist.pop_back_val();
683 
684       switch (S->getSCEVType()) {
685       case scConstant:
686       case scUnknown:
687         continue;
688       case scPtrToInt:
689       case scTruncate:
690       case scZeroExtend:
691       case scSignExtend:
692         push(cast<SCEVCastExpr>(S)->getOperand());
693         continue;
694       case scAddExpr:
695       case scMulExpr:
696       case scSMaxExpr:
697       case scUMaxExpr:
698       case scSMinExpr:
699       case scUMinExpr:
700       case scSequentialUMinExpr:
701       case scAddRecExpr:
702         for (const auto *Op : cast<SCEVNAryExpr>(S)->operands())
703           push(Op);
704         continue;
705       case scUDivExpr: {
706         const SCEVUDivExpr *UDiv = cast<SCEVUDivExpr>(S);
707         push(UDiv->getLHS());
708         push(UDiv->getRHS());
709         continue;
710       }
711       case scCouldNotCompute:
712         llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
713       }
714       llvm_unreachable("Unknown SCEV kind!");
715     }
716   }
717 };
718 
719 /// Use SCEVTraversal to visit all nodes in the given expression tree.
720 template <typename SV> void visitAll(const SCEV *Root, SV &Visitor) {
721   SCEVTraversal<SV> T(Visitor);
722   T.visitAll(Root);
723 }
724 
725 /// Return true if any node in \p Root satisfies the predicate \p Pred.
726 template <typename PredTy>
727 bool SCEVExprContains(const SCEV *Root, PredTy Pred) {
728   struct FindClosure {
729     bool Found = false;
730     PredTy Pred;
731 
732     FindClosure(PredTy Pred) : Pred(Pred) {}
733 
734     bool follow(const SCEV *S) {
735       if (!Pred(S))
736         return true;
737 
738       Found = true;
739       return false;
740     }
741 
742     bool isDone() const { return Found; }
743   };
744 
745   FindClosure FC(Pred);
746   visitAll(Root, FC);
747   return FC.Found;
748 }
749 
750 /// This visitor recursively visits a SCEV expression and re-writes it.
751 /// The result from each visit is cached, so it will return the same
752 /// SCEV for the same input.
753 template <typename SC>
754 class SCEVRewriteVisitor : public SCEVVisitor<SC, const SCEV *> {
755 protected:
756   ScalarEvolution &SE;
757   // Memoize the result of each visit so that we only compute once for
758   // the same input SCEV. This is to avoid redundant computations when
759   // a SCEV is referenced by multiple SCEVs. Without memoization, this
760   // visit algorithm would have exponential time complexity in the worst
761   // case, causing the compiler to hang on certain tests.
762   DenseMap<const SCEV *, const SCEV *> RewriteResults;
763 
764 public:
765   SCEVRewriteVisitor(ScalarEvolution &SE) : SE(SE) {}
766 
767   const SCEV *visit(const SCEV *S) {
768     auto It = RewriteResults.find(S);
769     if (It != RewriteResults.end())
770       return It->second;
771     auto *Visited = SCEVVisitor<SC, const SCEV *>::visit(S);
772     auto Result = RewriteResults.try_emplace(S, Visited);
773     assert(Result.second && "Should insert a new entry");
774     return Result.first->second;
775   }
776 
777   const SCEV *visitConstant(const SCEVConstant *Constant) { return Constant; }
778 
779   const SCEV *visitPtrToIntExpr(const SCEVPtrToIntExpr *Expr) {
780     const SCEV *Operand = ((SC *)this)->visit(Expr->getOperand());
781     return Operand == Expr->getOperand()
782                ? Expr
783                : SE.getPtrToIntExpr(Operand, Expr->getType());
784   }
785 
786   const SCEV *visitTruncateExpr(const SCEVTruncateExpr *Expr) {
787     const SCEV *Operand = ((SC *)this)->visit(Expr->getOperand());
788     return Operand == Expr->getOperand()
789                ? Expr
790                : SE.getTruncateExpr(Operand, Expr->getType());
791   }
792 
793   const SCEV *visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) {
794     const SCEV *Operand = ((SC *)this)->visit(Expr->getOperand());
795     return Operand == Expr->getOperand()
796                ? Expr
797                : SE.getZeroExtendExpr(Operand, Expr->getType());
798   }
799 
800   const SCEV *visitSignExtendExpr(const SCEVSignExtendExpr *Expr) {
801     const SCEV *Operand = ((SC *)this)->visit(Expr->getOperand());
802     return Operand == Expr->getOperand()
803                ? Expr
804                : SE.getSignExtendExpr(Operand, Expr->getType());
805   }
806 
807   const SCEV *visitAddExpr(const SCEVAddExpr *Expr) {
808     SmallVector<const SCEV *, 2> Operands;
809     bool Changed = false;
810     for (auto *Op : Expr->operands()) {
811       Operands.push_back(((SC *)this)->visit(Op));
812       Changed |= Op != Operands.back();
813     }
814     return !Changed ? Expr : SE.getAddExpr(Operands);
815   }
816 
817   const SCEV *visitMulExpr(const SCEVMulExpr *Expr) {
818     SmallVector<const SCEV *, 2> Operands;
819     bool Changed = false;
820     for (auto *Op : Expr->operands()) {
821       Operands.push_back(((SC *)this)->visit(Op));
822       Changed |= Op != Operands.back();
823     }
824     return !Changed ? Expr : SE.getMulExpr(Operands);
825   }
826 
827   const SCEV *visitUDivExpr(const SCEVUDivExpr *Expr) {
828     auto *LHS = ((SC *)this)->visit(Expr->getLHS());
829     auto *RHS = ((SC *)this)->visit(Expr->getRHS());
830     bool Changed = LHS != Expr->getLHS() || RHS != Expr->getRHS();
831     return !Changed ? Expr : SE.getUDivExpr(LHS, RHS);
832   }
833 
834   const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) {
835     SmallVector<const SCEV *, 2> Operands;
836     bool Changed = false;
837     for (auto *Op : Expr->operands()) {
838       Operands.push_back(((SC *)this)->visit(Op));
839       Changed |= Op != Operands.back();
840     }
841     return !Changed ? Expr
842                     : SE.getAddRecExpr(Operands, Expr->getLoop(),
843                                        Expr->getNoWrapFlags());
844   }
845 
846   const SCEV *visitSMaxExpr(const SCEVSMaxExpr *Expr) {
847     SmallVector<const SCEV *, 2> Operands;
848     bool Changed = false;
849     for (auto *Op : Expr->operands()) {
850       Operands.push_back(((SC *)this)->visit(Op));
851       Changed |= Op != Operands.back();
852     }
853     return !Changed ? Expr : SE.getSMaxExpr(Operands);
854   }
855 
856   const SCEV *visitUMaxExpr(const SCEVUMaxExpr *Expr) {
857     SmallVector<const SCEV *, 2> Operands;
858     bool Changed = false;
859     for (auto *Op : Expr->operands()) {
860       Operands.push_back(((SC *)this)->visit(Op));
861       Changed |= Op != Operands.back();
862     }
863     return !Changed ? Expr : SE.getUMaxExpr(Operands);
864   }
865 
866   const SCEV *visitSMinExpr(const SCEVSMinExpr *Expr) {
867     SmallVector<const SCEV *, 2> Operands;
868     bool Changed = false;
869     for (auto *Op : Expr->operands()) {
870       Operands.push_back(((SC *)this)->visit(Op));
871       Changed |= Op != Operands.back();
872     }
873     return !Changed ? Expr : SE.getSMinExpr(Operands);
874   }
875 
876   const SCEV *visitUMinExpr(const SCEVUMinExpr *Expr) {
877     SmallVector<const SCEV *, 2> Operands;
878     bool Changed = false;
879     for (auto *Op : Expr->operands()) {
880       Operands.push_back(((SC *)this)->visit(Op));
881       Changed |= Op != Operands.back();
882     }
883     return !Changed ? Expr : SE.getUMinExpr(Operands);
884   }
885 
886   const SCEV *visitSequentialUMinExpr(const SCEVSequentialUMinExpr *Expr) {
887     SmallVector<const SCEV *, 2> Operands;
888     bool Changed = false;
889     for (auto *Op : Expr->operands()) {
890       Operands.push_back(((SC *)this)->visit(Op));
891       Changed |= Op != Operands.back();
892     }
893     return !Changed ? Expr : SE.getUMinExpr(Operands, /*Sequential=*/true);
894   }
895 
896   const SCEV *visitUnknown(const SCEVUnknown *Expr) { return Expr; }
897 
898   const SCEV *visitCouldNotCompute(const SCEVCouldNotCompute *Expr) {
899     return Expr;
900   }
901 };
902 
903 using ValueToValueMap = DenseMap<const Value *, Value *>;
904 using ValueToSCEVMapTy = DenseMap<const Value *, const SCEV *>;
905 
906 /// The SCEVParameterRewriter takes a scalar evolution expression and updates
907 /// the SCEVUnknown components following the Map (Value -> SCEV).
908 class SCEVParameterRewriter : public SCEVRewriteVisitor<SCEVParameterRewriter> {
909 public:
910   static const SCEV *rewrite(const SCEV *Scev, ScalarEvolution &SE,
911                              ValueToSCEVMapTy &Map) {
912     SCEVParameterRewriter Rewriter(SE, Map);
913     return Rewriter.visit(Scev);
914   }
915 
916   SCEVParameterRewriter(ScalarEvolution &SE, ValueToSCEVMapTy &M)
917       : SCEVRewriteVisitor(SE), Map(M) {}
918 
919   const SCEV *visitUnknown(const SCEVUnknown *Expr) {
920     auto I = Map.find(Expr->getValue());
921     if (I == Map.end())
922       return Expr;
923     return I->second;
924   }
925 
926 private:
927   ValueToSCEVMapTy &Map;
928 };
929 
930 using LoopToScevMapT = DenseMap<const Loop *, const SCEV *>;
931 
932 /// The SCEVLoopAddRecRewriter takes a scalar evolution expression and applies
933 /// the Map (Loop -> SCEV) to all AddRecExprs.
934 class SCEVLoopAddRecRewriter
935     : public SCEVRewriteVisitor<SCEVLoopAddRecRewriter> {
936 public:
937   SCEVLoopAddRecRewriter(ScalarEvolution &SE, LoopToScevMapT &M)
938       : SCEVRewriteVisitor(SE), Map(M) {}
939 
940   static const SCEV *rewrite(const SCEV *Scev, LoopToScevMapT &Map,
941                              ScalarEvolution &SE) {
942     SCEVLoopAddRecRewriter Rewriter(SE, Map);
943     return Rewriter.visit(Scev);
944   }
945 
946   const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) {
947     SmallVector<const SCEV *, 2> Operands;
948     for (const SCEV *Op : Expr->operands())
949       Operands.push_back(visit(Op));
950 
951     const Loop *L = Expr->getLoop();
952     if (0 == Map.count(L))
953       return SE.getAddRecExpr(Operands, L, Expr->getNoWrapFlags());
954 
955     return SCEVAddRecExpr::evaluateAtIteration(Operands, Map[L], SE);
956   }
957 
958 private:
959   LoopToScevMapT &Map;
960 };
961 
962 } // end namespace llvm
963 
964 #endif // LLVM_ANALYSIS_SCALAREVOLUTIONEXPRESSIONS_H
965