1 //===- llvm/Analysis/ScalarEvolutionExpressions.h - SCEV Exprs --*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file defines the classes used to represent and build scalar expressions. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #ifndef LLVM_ANALYSIS_SCALAREVOLUTIONEXPRESSIONS_H 14 #define LLVM_ANALYSIS_SCALAREVOLUTIONEXPRESSIONS_H 15 16 #include "llvm/ADT/DenseMap.h" 17 #include "llvm/ADT/FoldingSet.h" 18 #include "llvm/ADT/SmallPtrSet.h" 19 #include "llvm/ADT/SmallVector.h" 20 #include "llvm/ADT/iterator_range.h" 21 #include "llvm/Analysis/ScalarEvolution.h" 22 #include "llvm/IR/Constants.h" 23 #include "llvm/IR/Value.h" 24 #include "llvm/IR/ValueHandle.h" 25 #include "llvm/Support/Casting.h" 26 #include "llvm/Support/ErrorHandling.h" 27 #include <cassert> 28 #include <cstddef> 29 30 namespace llvm { 31 32 class APInt; 33 class Constant; 34 class ConstantRange; 35 class Loop; 36 class Type; 37 38 enum SCEVTypes : unsigned short { 39 // These should be ordered in terms of increasing complexity to make the 40 // folders simpler. 41 scConstant, 42 scTruncate, 43 scZeroExtend, 44 scSignExtend, 45 scAddExpr, 46 scMulExpr, 47 scUDivExpr, 48 scAddRecExpr, 49 scUMaxExpr, 50 scSMaxExpr, 51 scUMinExpr, 52 scSMinExpr, 53 scSequentialUMinExpr, 54 scPtrToInt, 55 scUnknown, 56 scCouldNotCompute 57 }; 58 59 /// This class represents a constant integer value. 60 class SCEVConstant : public SCEV { 61 friend class ScalarEvolution; 62 63 ConstantInt *V; 64 65 SCEVConstant(const FoldingSetNodeIDRef ID, ConstantInt *v) 66 : SCEV(ID, scConstant, 1), V(v) {} 67 68 public: 69 ConstantInt *getValue() const { return V; } 70 const APInt &getAPInt() const { return getValue()->getValue(); } 71 72 Type *getType() const { return V->getType(); } 73 74 /// Methods for support type inquiry through isa, cast, and dyn_cast: 75 static bool classof(const SCEV *S) { return S->getSCEVType() == scConstant; } 76 }; 77 78 inline unsigned short computeExpressionSize(ArrayRef<const SCEV *> Args) { 79 APInt Size(16, 1); 80 for (auto *Arg : Args) 81 Size = Size.uadd_sat(APInt(16, Arg->getExpressionSize())); 82 return (unsigned short)Size.getZExtValue(); 83 } 84 85 /// This is the base class for unary cast operator classes. 86 class SCEVCastExpr : public SCEV { 87 protected: 88 std::array<const SCEV *, 1> Operands; 89 Type *Ty; 90 91 SCEVCastExpr(const FoldingSetNodeIDRef ID, SCEVTypes SCEVTy, const SCEV *op, 92 Type *ty); 93 94 public: 95 const SCEV *getOperand() const { return Operands[0]; } 96 const SCEV *getOperand(unsigned i) const { 97 assert(i == 0 && "Operand index out of range!"); 98 return Operands[0]; 99 } 100 using op_iterator = std::array<const SCEV *, 1>::const_iterator; 101 using op_range = iterator_range<op_iterator>; 102 103 op_range operands() const { 104 return make_range(Operands.begin(), Operands.end()); 105 } 106 size_t getNumOperands() const { return 1; } 107 Type *getType() const { return Ty; } 108 109 /// Methods for support type inquiry through isa, cast, and dyn_cast: 110 static bool classof(const SCEV *S) { 111 return S->getSCEVType() == scPtrToInt || S->getSCEVType() == scTruncate || 112 S->getSCEVType() == scZeroExtend || S->getSCEVType() == scSignExtend; 113 } 114 }; 115 116 /// This class represents a cast from a pointer to a pointer-sized integer 117 /// value. 118 class SCEVPtrToIntExpr : public SCEVCastExpr { 119 friend class ScalarEvolution; 120 121 SCEVPtrToIntExpr(const FoldingSetNodeIDRef ID, const SCEV *Op, Type *ITy); 122 123 public: 124 /// Methods for support type inquiry through isa, cast, and dyn_cast: 125 static bool classof(const SCEV *S) { return S->getSCEVType() == scPtrToInt; } 126 }; 127 128 /// This is the base class for unary integral cast operator classes. 129 class SCEVIntegralCastExpr : public SCEVCastExpr { 130 protected: 131 SCEVIntegralCastExpr(const FoldingSetNodeIDRef ID, SCEVTypes SCEVTy, 132 const SCEV *op, Type *ty); 133 134 public: 135 /// Methods for support type inquiry through isa, cast, and dyn_cast: 136 static bool classof(const SCEV *S) { 137 return S->getSCEVType() == scTruncate || S->getSCEVType() == scZeroExtend || 138 S->getSCEVType() == scSignExtend; 139 } 140 }; 141 142 /// This class represents a truncation of an integer value to a 143 /// smaller integer value. 144 class SCEVTruncateExpr : public SCEVIntegralCastExpr { 145 friend class ScalarEvolution; 146 147 SCEVTruncateExpr(const FoldingSetNodeIDRef ID, const SCEV *op, Type *ty); 148 149 public: 150 /// Methods for support type inquiry through isa, cast, and dyn_cast: 151 static bool classof(const SCEV *S) { return S->getSCEVType() == scTruncate; } 152 }; 153 154 /// This class represents a zero extension of a small integer value 155 /// to a larger integer value. 156 class SCEVZeroExtendExpr : public SCEVIntegralCastExpr { 157 friend class ScalarEvolution; 158 159 SCEVZeroExtendExpr(const FoldingSetNodeIDRef ID, const SCEV *op, Type *ty); 160 161 public: 162 /// Methods for support type inquiry through isa, cast, and dyn_cast: 163 static bool classof(const SCEV *S) { 164 return S->getSCEVType() == scZeroExtend; 165 } 166 }; 167 168 /// This class represents a sign extension of a small integer value 169 /// to a larger integer value. 170 class SCEVSignExtendExpr : public SCEVIntegralCastExpr { 171 friend class ScalarEvolution; 172 173 SCEVSignExtendExpr(const FoldingSetNodeIDRef ID, const SCEV *op, Type *ty); 174 175 public: 176 /// Methods for support type inquiry through isa, cast, and dyn_cast: 177 static bool classof(const SCEV *S) { 178 return S->getSCEVType() == scSignExtend; 179 } 180 }; 181 182 /// This node is a base class providing common functionality for 183 /// n'ary operators. 184 class SCEVNAryExpr : public SCEV { 185 protected: 186 // Since SCEVs are immutable, ScalarEvolution allocates operand 187 // arrays with its SCEVAllocator, so this class just needs a simple 188 // pointer rather than a more elaborate vector-like data structure. 189 // This also avoids the need for a non-trivial destructor. 190 const SCEV *const *Operands; 191 size_t NumOperands; 192 193 SCEVNAryExpr(const FoldingSetNodeIDRef ID, enum SCEVTypes T, 194 const SCEV *const *O, size_t N) 195 : SCEV(ID, T, computeExpressionSize(makeArrayRef(O, N))), Operands(O), 196 NumOperands(N) {} 197 198 public: 199 size_t getNumOperands() const { return NumOperands; } 200 201 const SCEV *getOperand(unsigned i) const { 202 assert(i < NumOperands && "Operand index out of range!"); 203 return Operands[i]; 204 } 205 206 using op_iterator = const SCEV *const *; 207 using op_range = iterator_range<op_iterator>; 208 209 op_iterator op_begin() const { return Operands; } 210 op_iterator op_end() const { return Operands + NumOperands; } 211 op_range operands() const { return make_range(op_begin(), op_end()); } 212 213 NoWrapFlags getNoWrapFlags(NoWrapFlags Mask = NoWrapMask) const { 214 return (NoWrapFlags)(SubclassData & Mask); 215 } 216 217 bool hasNoUnsignedWrap() const { 218 return getNoWrapFlags(FlagNUW) != FlagAnyWrap; 219 } 220 221 bool hasNoSignedWrap() const { 222 return getNoWrapFlags(FlagNSW) != FlagAnyWrap; 223 } 224 225 bool hasNoSelfWrap() const { return getNoWrapFlags(FlagNW) != FlagAnyWrap; } 226 227 /// Methods for support type inquiry through isa, cast, and dyn_cast: 228 static bool classof(const SCEV *S) { 229 return S->getSCEVType() == scAddExpr || S->getSCEVType() == scMulExpr || 230 S->getSCEVType() == scSMaxExpr || S->getSCEVType() == scUMaxExpr || 231 S->getSCEVType() == scSMinExpr || S->getSCEVType() == scUMinExpr || 232 S->getSCEVType() == scSequentialUMinExpr || 233 S->getSCEVType() == scAddRecExpr; 234 } 235 }; 236 237 /// This node is the base class for n'ary commutative operators. 238 class SCEVCommutativeExpr : public SCEVNAryExpr { 239 protected: 240 SCEVCommutativeExpr(const FoldingSetNodeIDRef ID, enum SCEVTypes T, 241 const SCEV *const *O, size_t N) 242 : SCEVNAryExpr(ID, T, O, N) {} 243 244 public: 245 /// Methods for support type inquiry through isa, cast, and dyn_cast: 246 static bool classof(const SCEV *S) { 247 return S->getSCEVType() == scAddExpr || S->getSCEVType() == scMulExpr || 248 S->getSCEVType() == scSMaxExpr || S->getSCEVType() == scUMaxExpr || 249 S->getSCEVType() == scSMinExpr || S->getSCEVType() == scUMinExpr; 250 } 251 252 /// Set flags for a non-recurrence without clearing previously set flags. 253 void setNoWrapFlags(NoWrapFlags Flags) { SubclassData |= Flags; } 254 }; 255 256 /// This node represents an addition of some number of SCEVs. 257 class SCEVAddExpr : public SCEVCommutativeExpr { 258 friend class ScalarEvolution; 259 260 Type *Ty; 261 262 SCEVAddExpr(const FoldingSetNodeIDRef ID, const SCEV *const *O, size_t N) 263 : SCEVCommutativeExpr(ID, scAddExpr, O, N) { 264 auto *FirstPointerTypedOp = find_if(operands(), [](const SCEV *Op) { 265 return Op->getType()->isPointerTy(); 266 }); 267 if (FirstPointerTypedOp != operands().end()) 268 Ty = (*FirstPointerTypedOp)->getType(); 269 else 270 Ty = getOperand(0)->getType(); 271 } 272 273 public: 274 Type *getType() const { return Ty; } 275 276 /// Methods for support type inquiry through isa, cast, and dyn_cast: 277 static bool classof(const SCEV *S) { return S->getSCEVType() == scAddExpr; } 278 }; 279 280 /// This node represents multiplication of some number of SCEVs. 281 class SCEVMulExpr : public SCEVCommutativeExpr { 282 friend class ScalarEvolution; 283 284 SCEVMulExpr(const FoldingSetNodeIDRef ID, const SCEV *const *O, size_t N) 285 : SCEVCommutativeExpr(ID, scMulExpr, O, N) {} 286 287 public: 288 Type *getType() const { return getOperand(0)->getType(); } 289 290 /// Methods for support type inquiry through isa, cast, and dyn_cast: 291 static bool classof(const SCEV *S) { return S->getSCEVType() == scMulExpr; } 292 }; 293 294 /// This class represents a binary unsigned division operation. 295 class SCEVUDivExpr : public SCEV { 296 friend class ScalarEvolution; 297 298 std::array<const SCEV *, 2> Operands; 299 300 SCEVUDivExpr(const FoldingSetNodeIDRef ID, const SCEV *lhs, const SCEV *rhs) 301 : SCEV(ID, scUDivExpr, computeExpressionSize({lhs, rhs})) { 302 Operands[0] = lhs; 303 Operands[1] = rhs; 304 } 305 306 public: 307 const SCEV *getLHS() const { return Operands[0]; } 308 const SCEV *getRHS() const { return Operands[1]; } 309 size_t getNumOperands() const { return 2; } 310 const SCEV *getOperand(unsigned i) const { 311 assert((i == 0 || i == 1) && "Operand index out of range!"); 312 return i == 0 ? getLHS() : getRHS(); 313 } 314 315 using op_iterator = std::array<const SCEV *, 2>::const_iterator; 316 using op_range = iterator_range<op_iterator>; 317 op_range operands() const { 318 return make_range(Operands.begin(), Operands.end()); 319 } 320 321 Type *getType() const { 322 // In most cases the types of LHS and RHS will be the same, but in some 323 // crazy cases one or the other may be a pointer. ScalarEvolution doesn't 324 // depend on the type for correctness, but handling types carefully can 325 // avoid extra casts in the SCEVExpander. The LHS is more likely to be 326 // a pointer type than the RHS, so use the RHS' type here. 327 return getRHS()->getType(); 328 } 329 330 /// Methods for support type inquiry through isa, cast, and dyn_cast: 331 static bool classof(const SCEV *S) { return S->getSCEVType() == scUDivExpr; } 332 }; 333 334 /// This node represents a polynomial recurrence on the trip count 335 /// of the specified loop. This is the primary focus of the 336 /// ScalarEvolution framework; all the other SCEV subclasses are 337 /// mostly just supporting infrastructure to allow SCEVAddRecExpr 338 /// expressions to be created and analyzed. 339 /// 340 /// All operands of an AddRec are required to be loop invariant. 341 /// 342 class SCEVAddRecExpr : public SCEVNAryExpr { 343 friend class ScalarEvolution; 344 345 const Loop *L; 346 347 SCEVAddRecExpr(const FoldingSetNodeIDRef ID, const SCEV *const *O, size_t N, 348 const Loop *l) 349 : SCEVNAryExpr(ID, scAddRecExpr, O, N), L(l) {} 350 351 public: 352 Type *getType() const { return getStart()->getType(); } 353 const SCEV *getStart() const { return Operands[0]; } 354 const Loop *getLoop() const { return L; } 355 356 /// Constructs and returns the recurrence indicating how much this 357 /// expression steps by. If this is a polynomial of degree N, it 358 /// returns a chrec of degree N-1. We cannot determine whether 359 /// the step recurrence has self-wraparound. 360 const SCEV *getStepRecurrence(ScalarEvolution &SE) const { 361 if (isAffine()) 362 return getOperand(1); 363 return SE.getAddRecExpr( 364 SmallVector<const SCEV *, 3>(op_begin() + 1, op_end()), getLoop(), 365 FlagAnyWrap); 366 } 367 368 /// Return true if this represents an expression A + B*x where A 369 /// and B are loop invariant values. 370 bool isAffine() const { 371 // We know that the start value is invariant. This expression is thus 372 // affine iff the step is also invariant. 373 return getNumOperands() == 2; 374 } 375 376 /// Return true if this represents an expression A + B*x + C*x^2 377 /// where A, B and C are loop invariant values. This corresponds 378 /// to an addrec of the form {L,+,M,+,N} 379 bool isQuadratic() const { return getNumOperands() == 3; } 380 381 /// Set flags for a recurrence without clearing any previously set flags. 382 /// For AddRec, either NUW or NSW implies NW. Keep track of this fact here 383 /// to make it easier to propagate flags. 384 void setNoWrapFlags(NoWrapFlags Flags) { 385 if (Flags & (FlagNUW | FlagNSW)) 386 Flags = ScalarEvolution::setFlags(Flags, FlagNW); 387 SubclassData |= Flags; 388 } 389 390 /// Return the value of this chain of recurrences at the specified 391 /// iteration number. 392 const SCEV *evaluateAtIteration(const SCEV *It, ScalarEvolution &SE) const; 393 394 /// Return the value of this chain of recurrences at the specified iteration 395 /// number. Takes an explicit list of operands to represent an AddRec. 396 static const SCEV *evaluateAtIteration(ArrayRef<const SCEV *> Operands, 397 const SCEV *It, ScalarEvolution &SE); 398 399 /// Return the number of iterations of this loop that produce 400 /// values in the specified constant range. Another way of 401 /// looking at this is that it returns the first iteration number 402 /// where the value is not in the condition, thus computing the 403 /// exit count. If the iteration count can't be computed, an 404 /// instance of SCEVCouldNotCompute is returned. 405 const SCEV *getNumIterationsInRange(const ConstantRange &Range, 406 ScalarEvolution &SE) const; 407 408 /// Return an expression representing the value of this expression 409 /// one iteration of the loop ahead. 410 const SCEVAddRecExpr *getPostIncExpr(ScalarEvolution &SE) const; 411 412 /// Methods for support type inquiry through isa, cast, and dyn_cast: 413 static bool classof(const SCEV *S) { 414 return S->getSCEVType() == scAddRecExpr; 415 } 416 }; 417 418 /// This node is the base class min/max selections. 419 class SCEVMinMaxExpr : public SCEVCommutativeExpr { 420 friend class ScalarEvolution; 421 422 static bool isMinMaxType(enum SCEVTypes T) { 423 return T == scSMaxExpr || T == scUMaxExpr || T == scSMinExpr || 424 T == scUMinExpr; 425 } 426 427 protected: 428 /// Note: Constructing subclasses via this constructor is allowed 429 SCEVMinMaxExpr(const FoldingSetNodeIDRef ID, enum SCEVTypes T, 430 const SCEV *const *O, size_t N) 431 : SCEVCommutativeExpr(ID, T, O, N) { 432 assert(isMinMaxType(T)); 433 // Min and max never overflow 434 setNoWrapFlags((NoWrapFlags)(FlagNUW | FlagNSW)); 435 } 436 437 public: 438 Type *getType() const { return getOperand(0)->getType(); } 439 440 static bool classof(const SCEV *S) { return isMinMaxType(S->getSCEVType()); } 441 442 static enum SCEVTypes negate(enum SCEVTypes T) { 443 switch (T) { 444 case scSMaxExpr: 445 return scSMinExpr; 446 case scSMinExpr: 447 return scSMaxExpr; 448 case scUMaxExpr: 449 return scUMinExpr; 450 case scUMinExpr: 451 return scUMaxExpr; 452 default: 453 llvm_unreachable("Not a min or max SCEV type!"); 454 } 455 } 456 }; 457 458 /// This class represents a signed maximum selection. 459 class SCEVSMaxExpr : public SCEVMinMaxExpr { 460 friend class ScalarEvolution; 461 462 SCEVSMaxExpr(const FoldingSetNodeIDRef ID, const SCEV *const *O, size_t N) 463 : SCEVMinMaxExpr(ID, scSMaxExpr, O, N) {} 464 465 public: 466 /// Methods for support type inquiry through isa, cast, and dyn_cast: 467 static bool classof(const SCEV *S) { return S->getSCEVType() == scSMaxExpr; } 468 }; 469 470 /// This class represents an unsigned maximum selection. 471 class SCEVUMaxExpr : public SCEVMinMaxExpr { 472 friend class ScalarEvolution; 473 474 SCEVUMaxExpr(const FoldingSetNodeIDRef ID, const SCEV *const *O, size_t N) 475 : SCEVMinMaxExpr(ID, scUMaxExpr, O, N) {} 476 477 public: 478 /// Methods for support type inquiry through isa, cast, and dyn_cast: 479 static bool classof(const SCEV *S) { return S->getSCEVType() == scUMaxExpr; } 480 }; 481 482 /// This class represents a signed minimum selection. 483 class SCEVSMinExpr : public SCEVMinMaxExpr { 484 friend class ScalarEvolution; 485 486 SCEVSMinExpr(const FoldingSetNodeIDRef ID, const SCEV *const *O, size_t N) 487 : SCEVMinMaxExpr(ID, scSMinExpr, O, N) {} 488 489 public: 490 /// Methods for support type inquiry through isa, cast, and dyn_cast: 491 static bool classof(const SCEV *S) { return S->getSCEVType() == scSMinExpr; } 492 }; 493 494 /// This class represents an unsigned minimum selection. 495 class SCEVUMinExpr : public SCEVMinMaxExpr { 496 friend class ScalarEvolution; 497 498 SCEVUMinExpr(const FoldingSetNodeIDRef ID, const SCEV *const *O, size_t N) 499 : SCEVMinMaxExpr(ID, scUMinExpr, O, N) {} 500 501 public: 502 /// Methods for support type inquiry through isa, cast, and dyn_cast: 503 static bool classof(const SCEV *S) { return S->getSCEVType() == scUMinExpr; } 504 }; 505 506 /// This node is the base class for sequential/in-order min/max selections. 507 /// Note that their fundamental difference from SCEVMinMaxExpr's is that they 508 /// are early-returning upon reaching saturation point. 509 /// I.e. given `0 umin_seq poison`, the result will be `0`, 510 /// while the result of `0 umin poison` is `poison`. 511 class SCEVSequentialMinMaxExpr : public SCEVNAryExpr { 512 friend class ScalarEvolution; 513 514 static bool isSequentialMinMaxType(enum SCEVTypes T) { 515 return T == scSequentialUMinExpr; 516 } 517 518 /// Set flags for a non-recurrence without clearing previously set flags. 519 void setNoWrapFlags(NoWrapFlags Flags) { SubclassData |= Flags; } 520 521 protected: 522 /// Note: Constructing subclasses via this constructor is allowed 523 SCEVSequentialMinMaxExpr(const FoldingSetNodeIDRef ID, enum SCEVTypes T, 524 const SCEV *const *O, size_t N) 525 : SCEVNAryExpr(ID, T, O, N) { 526 assert(isSequentialMinMaxType(T)); 527 // Min and max never overflow 528 setNoWrapFlags((NoWrapFlags)(FlagNUW | FlagNSW)); 529 } 530 531 public: 532 Type *getType() const { return getOperand(0)->getType(); } 533 534 static SCEVTypes getEquivalentNonSequentialSCEVType(SCEVTypes Ty) { 535 assert(isSequentialMinMaxType(Ty)); 536 switch (Ty) { 537 case scSequentialUMinExpr: 538 return scUMinExpr; 539 default: 540 llvm_unreachable("Not a sequential min/max type."); 541 } 542 } 543 544 SCEVTypes getEquivalentNonSequentialSCEVType() const { 545 return getEquivalentNonSequentialSCEVType(getSCEVType()); 546 } 547 548 static bool classof(const SCEV *S) { 549 return isSequentialMinMaxType(S->getSCEVType()); 550 } 551 }; 552 553 /// This class represents a sequential/in-order unsigned minimum selection. 554 class SCEVSequentialUMinExpr : public SCEVSequentialMinMaxExpr { 555 friend class ScalarEvolution; 556 557 SCEVSequentialUMinExpr(const FoldingSetNodeIDRef ID, const SCEV *const *O, 558 size_t N) 559 : SCEVSequentialMinMaxExpr(ID, scSequentialUMinExpr, O, N) {} 560 561 public: 562 /// Methods for support type inquiry through isa, cast, and dyn_cast: 563 static bool classof(const SCEV *S) { 564 return S->getSCEVType() == scSequentialUMinExpr; 565 } 566 }; 567 568 /// This means that we are dealing with an entirely unknown SCEV 569 /// value, and only represent it as its LLVM Value. This is the 570 /// "bottom" value for the analysis. 571 class SCEVUnknown final : public SCEV, private CallbackVH { 572 friend class ScalarEvolution; 573 574 /// The parent ScalarEvolution value. This is used to update the 575 /// parent's maps when the value associated with a SCEVUnknown is 576 /// deleted or RAUW'd. 577 ScalarEvolution *SE; 578 579 /// The next pointer in the linked list of all SCEVUnknown 580 /// instances owned by a ScalarEvolution. 581 SCEVUnknown *Next; 582 583 SCEVUnknown(const FoldingSetNodeIDRef ID, Value *V, ScalarEvolution *se, 584 SCEVUnknown *next) 585 : SCEV(ID, scUnknown, 1), CallbackVH(V), SE(se), Next(next) {} 586 587 // Implement CallbackVH. 588 void deleted() override; 589 void allUsesReplacedWith(Value *New) override; 590 591 public: 592 Value *getValue() const { return getValPtr(); } 593 594 /// @{ 595 /// Test whether this is a special constant representing a type 596 /// size, alignment, or field offset in a target-independent 597 /// manner, and hasn't happened to have been folded with other 598 /// operations into something unrecognizable. This is mainly only 599 /// useful for pretty-printing and other situations where it isn't 600 /// absolutely required for these to succeed. 601 bool isSizeOf(Type *&AllocTy) const; 602 bool isAlignOf(Type *&AllocTy) const; 603 bool isOffsetOf(Type *&STy, Constant *&FieldNo) const; 604 /// @} 605 606 Type *getType() const { return getValPtr()->getType(); } 607 608 /// Methods for support type inquiry through isa, cast, and dyn_cast: 609 static bool classof(const SCEV *S) { return S->getSCEVType() == scUnknown; } 610 }; 611 612 /// This class defines a simple visitor class that may be used for 613 /// various SCEV analysis purposes. 614 template <typename SC, typename RetVal = void> struct SCEVVisitor { 615 RetVal visit(const SCEV *S) { 616 switch (S->getSCEVType()) { 617 case scConstant: 618 return ((SC *)this)->visitConstant((const SCEVConstant *)S); 619 case scPtrToInt: 620 return ((SC *)this)->visitPtrToIntExpr((const SCEVPtrToIntExpr *)S); 621 case scTruncate: 622 return ((SC *)this)->visitTruncateExpr((const SCEVTruncateExpr *)S); 623 case scZeroExtend: 624 return ((SC *)this)->visitZeroExtendExpr((const SCEVZeroExtendExpr *)S); 625 case scSignExtend: 626 return ((SC *)this)->visitSignExtendExpr((const SCEVSignExtendExpr *)S); 627 case scAddExpr: 628 return ((SC *)this)->visitAddExpr((const SCEVAddExpr *)S); 629 case scMulExpr: 630 return ((SC *)this)->visitMulExpr((const SCEVMulExpr *)S); 631 case scUDivExpr: 632 return ((SC *)this)->visitUDivExpr((const SCEVUDivExpr *)S); 633 case scAddRecExpr: 634 return ((SC *)this)->visitAddRecExpr((const SCEVAddRecExpr *)S); 635 case scSMaxExpr: 636 return ((SC *)this)->visitSMaxExpr((const SCEVSMaxExpr *)S); 637 case scUMaxExpr: 638 return ((SC *)this)->visitUMaxExpr((const SCEVUMaxExpr *)S); 639 case scSMinExpr: 640 return ((SC *)this)->visitSMinExpr((const SCEVSMinExpr *)S); 641 case scUMinExpr: 642 return ((SC *)this)->visitUMinExpr((const SCEVUMinExpr *)S); 643 case scSequentialUMinExpr: 644 return ((SC *)this) 645 ->visitSequentialUMinExpr((const SCEVSequentialUMinExpr *)S); 646 case scUnknown: 647 return ((SC *)this)->visitUnknown((const SCEVUnknown *)S); 648 case scCouldNotCompute: 649 return ((SC *)this)->visitCouldNotCompute((const SCEVCouldNotCompute *)S); 650 } 651 llvm_unreachable("Unknown SCEV kind!"); 652 } 653 654 RetVal visitCouldNotCompute(const SCEVCouldNotCompute *S) { 655 llvm_unreachable("Invalid use of SCEVCouldNotCompute!"); 656 } 657 }; 658 659 /// Visit all nodes in the expression tree using worklist traversal. 660 /// 661 /// Visitor implements: 662 /// // return true to follow this node. 663 /// bool follow(const SCEV *S); 664 /// // return true to terminate the search. 665 /// bool isDone(); 666 template <typename SV> class SCEVTraversal { 667 SV &Visitor; 668 SmallVector<const SCEV *, 8> Worklist; 669 SmallPtrSet<const SCEV *, 8> Visited; 670 671 void push(const SCEV *S) { 672 if (Visited.insert(S).second && Visitor.follow(S)) 673 Worklist.push_back(S); 674 } 675 676 public: 677 SCEVTraversal(SV &V) : Visitor(V) {} 678 679 void visitAll(const SCEV *Root) { 680 push(Root); 681 while (!Worklist.empty() && !Visitor.isDone()) { 682 const SCEV *S = Worklist.pop_back_val(); 683 684 switch (S->getSCEVType()) { 685 case scConstant: 686 case scUnknown: 687 continue; 688 case scPtrToInt: 689 case scTruncate: 690 case scZeroExtend: 691 case scSignExtend: 692 push(cast<SCEVCastExpr>(S)->getOperand()); 693 continue; 694 case scAddExpr: 695 case scMulExpr: 696 case scSMaxExpr: 697 case scUMaxExpr: 698 case scSMinExpr: 699 case scUMinExpr: 700 case scSequentialUMinExpr: 701 case scAddRecExpr: 702 for (const auto *Op : cast<SCEVNAryExpr>(S)->operands()) 703 push(Op); 704 continue; 705 case scUDivExpr: { 706 const SCEVUDivExpr *UDiv = cast<SCEVUDivExpr>(S); 707 push(UDiv->getLHS()); 708 push(UDiv->getRHS()); 709 continue; 710 } 711 case scCouldNotCompute: 712 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!"); 713 } 714 llvm_unreachable("Unknown SCEV kind!"); 715 } 716 } 717 }; 718 719 /// Use SCEVTraversal to visit all nodes in the given expression tree. 720 template <typename SV> void visitAll(const SCEV *Root, SV &Visitor) { 721 SCEVTraversal<SV> T(Visitor); 722 T.visitAll(Root); 723 } 724 725 /// Return true if any node in \p Root satisfies the predicate \p Pred. 726 template <typename PredTy> 727 bool SCEVExprContains(const SCEV *Root, PredTy Pred) { 728 struct FindClosure { 729 bool Found = false; 730 PredTy Pred; 731 732 FindClosure(PredTy Pred) : Pred(Pred) {} 733 734 bool follow(const SCEV *S) { 735 if (!Pred(S)) 736 return true; 737 738 Found = true; 739 return false; 740 } 741 742 bool isDone() const { return Found; } 743 }; 744 745 FindClosure FC(Pred); 746 visitAll(Root, FC); 747 return FC.Found; 748 } 749 750 /// This visitor recursively visits a SCEV expression and re-writes it. 751 /// The result from each visit is cached, so it will return the same 752 /// SCEV for the same input. 753 template <typename SC> 754 class SCEVRewriteVisitor : public SCEVVisitor<SC, const SCEV *> { 755 protected: 756 ScalarEvolution &SE; 757 // Memoize the result of each visit so that we only compute once for 758 // the same input SCEV. This is to avoid redundant computations when 759 // a SCEV is referenced by multiple SCEVs. Without memoization, this 760 // visit algorithm would have exponential time complexity in the worst 761 // case, causing the compiler to hang on certain tests. 762 DenseMap<const SCEV *, const SCEV *> RewriteResults; 763 764 public: 765 SCEVRewriteVisitor(ScalarEvolution &SE) : SE(SE) {} 766 767 const SCEV *visit(const SCEV *S) { 768 auto It = RewriteResults.find(S); 769 if (It != RewriteResults.end()) 770 return It->second; 771 auto *Visited = SCEVVisitor<SC, const SCEV *>::visit(S); 772 auto Result = RewriteResults.try_emplace(S, Visited); 773 assert(Result.second && "Should insert a new entry"); 774 return Result.first->second; 775 } 776 777 const SCEV *visitConstant(const SCEVConstant *Constant) { return Constant; } 778 779 const SCEV *visitPtrToIntExpr(const SCEVPtrToIntExpr *Expr) { 780 const SCEV *Operand = ((SC *)this)->visit(Expr->getOperand()); 781 return Operand == Expr->getOperand() 782 ? Expr 783 : SE.getPtrToIntExpr(Operand, Expr->getType()); 784 } 785 786 const SCEV *visitTruncateExpr(const SCEVTruncateExpr *Expr) { 787 const SCEV *Operand = ((SC *)this)->visit(Expr->getOperand()); 788 return Operand == Expr->getOperand() 789 ? Expr 790 : SE.getTruncateExpr(Operand, Expr->getType()); 791 } 792 793 const SCEV *visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) { 794 const SCEV *Operand = ((SC *)this)->visit(Expr->getOperand()); 795 return Operand == Expr->getOperand() 796 ? Expr 797 : SE.getZeroExtendExpr(Operand, Expr->getType()); 798 } 799 800 const SCEV *visitSignExtendExpr(const SCEVSignExtendExpr *Expr) { 801 const SCEV *Operand = ((SC *)this)->visit(Expr->getOperand()); 802 return Operand == Expr->getOperand() 803 ? Expr 804 : SE.getSignExtendExpr(Operand, Expr->getType()); 805 } 806 807 const SCEV *visitAddExpr(const SCEVAddExpr *Expr) { 808 SmallVector<const SCEV *, 2> Operands; 809 bool Changed = false; 810 for (auto *Op : Expr->operands()) { 811 Operands.push_back(((SC *)this)->visit(Op)); 812 Changed |= Op != Operands.back(); 813 } 814 return !Changed ? Expr : SE.getAddExpr(Operands); 815 } 816 817 const SCEV *visitMulExpr(const SCEVMulExpr *Expr) { 818 SmallVector<const SCEV *, 2> Operands; 819 bool Changed = false; 820 for (auto *Op : Expr->operands()) { 821 Operands.push_back(((SC *)this)->visit(Op)); 822 Changed |= Op != Operands.back(); 823 } 824 return !Changed ? Expr : SE.getMulExpr(Operands); 825 } 826 827 const SCEV *visitUDivExpr(const SCEVUDivExpr *Expr) { 828 auto *LHS = ((SC *)this)->visit(Expr->getLHS()); 829 auto *RHS = ((SC *)this)->visit(Expr->getRHS()); 830 bool Changed = LHS != Expr->getLHS() || RHS != Expr->getRHS(); 831 return !Changed ? Expr : SE.getUDivExpr(LHS, RHS); 832 } 833 834 const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) { 835 SmallVector<const SCEV *, 2> Operands; 836 bool Changed = false; 837 for (auto *Op : Expr->operands()) { 838 Operands.push_back(((SC *)this)->visit(Op)); 839 Changed |= Op != Operands.back(); 840 } 841 return !Changed ? Expr 842 : SE.getAddRecExpr(Operands, Expr->getLoop(), 843 Expr->getNoWrapFlags()); 844 } 845 846 const SCEV *visitSMaxExpr(const SCEVSMaxExpr *Expr) { 847 SmallVector<const SCEV *, 2> Operands; 848 bool Changed = false; 849 for (auto *Op : Expr->operands()) { 850 Operands.push_back(((SC *)this)->visit(Op)); 851 Changed |= Op != Operands.back(); 852 } 853 return !Changed ? Expr : SE.getSMaxExpr(Operands); 854 } 855 856 const SCEV *visitUMaxExpr(const SCEVUMaxExpr *Expr) { 857 SmallVector<const SCEV *, 2> Operands; 858 bool Changed = false; 859 for (auto *Op : Expr->operands()) { 860 Operands.push_back(((SC *)this)->visit(Op)); 861 Changed |= Op != Operands.back(); 862 } 863 return !Changed ? Expr : SE.getUMaxExpr(Operands); 864 } 865 866 const SCEV *visitSMinExpr(const SCEVSMinExpr *Expr) { 867 SmallVector<const SCEV *, 2> Operands; 868 bool Changed = false; 869 for (auto *Op : Expr->operands()) { 870 Operands.push_back(((SC *)this)->visit(Op)); 871 Changed |= Op != Operands.back(); 872 } 873 return !Changed ? Expr : SE.getSMinExpr(Operands); 874 } 875 876 const SCEV *visitUMinExpr(const SCEVUMinExpr *Expr) { 877 SmallVector<const SCEV *, 2> Operands; 878 bool Changed = false; 879 for (auto *Op : Expr->operands()) { 880 Operands.push_back(((SC *)this)->visit(Op)); 881 Changed |= Op != Operands.back(); 882 } 883 return !Changed ? Expr : SE.getUMinExpr(Operands); 884 } 885 886 const SCEV *visitSequentialUMinExpr(const SCEVSequentialUMinExpr *Expr) { 887 SmallVector<const SCEV *, 2> Operands; 888 bool Changed = false; 889 for (auto *Op : Expr->operands()) { 890 Operands.push_back(((SC *)this)->visit(Op)); 891 Changed |= Op != Operands.back(); 892 } 893 return !Changed ? Expr : SE.getUMinExpr(Operands, /*Sequential=*/true); 894 } 895 896 const SCEV *visitUnknown(const SCEVUnknown *Expr) { return Expr; } 897 898 const SCEV *visitCouldNotCompute(const SCEVCouldNotCompute *Expr) { 899 return Expr; 900 } 901 }; 902 903 using ValueToValueMap = DenseMap<const Value *, Value *>; 904 using ValueToSCEVMapTy = DenseMap<const Value *, const SCEV *>; 905 906 /// The SCEVParameterRewriter takes a scalar evolution expression and updates 907 /// the SCEVUnknown components following the Map (Value -> SCEV). 908 class SCEVParameterRewriter : public SCEVRewriteVisitor<SCEVParameterRewriter> { 909 public: 910 static const SCEV *rewrite(const SCEV *Scev, ScalarEvolution &SE, 911 ValueToSCEVMapTy &Map) { 912 SCEVParameterRewriter Rewriter(SE, Map); 913 return Rewriter.visit(Scev); 914 } 915 916 SCEVParameterRewriter(ScalarEvolution &SE, ValueToSCEVMapTy &M) 917 : SCEVRewriteVisitor(SE), Map(M) {} 918 919 const SCEV *visitUnknown(const SCEVUnknown *Expr) { 920 auto I = Map.find(Expr->getValue()); 921 if (I == Map.end()) 922 return Expr; 923 return I->second; 924 } 925 926 private: 927 ValueToSCEVMapTy ⤅ 928 }; 929 930 using LoopToScevMapT = DenseMap<const Loop *, const SCEV *>; 931 932 /// The SCEVLoopAddRecRewriter takes a scalar evolution expression and applies 933 /// the Map (Loop -> SCEV) to all AddRecExprs. 934 class SCEVLoopAddRecRewriter 935 : public SCEVRewriteVisitor<SCEVLoopAddRecRewriter> { 936 public: 937 SCEVLoopAddRecRewriter(ScalarEvolution &SE, LoopToScevMapT &M) 938 : SCEVRewriteVisitor(SE), Map(M) {} 939 940 static const SCEV *rewrite(const SCEV *Scev, LoopToScevMapT &Map, 941 ScalarEvolution &SE) { 942 SCEVLoopAddRecRewriter Rewriter(SE, Map); 943 return Rewriter.visit(Scev); 944 } 945 946 const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) { 947 SmallVector<const SCEV *, 2> Operands; 948 for (const SCEV *Op : Expr->operands()) 949 Operands.push_back(visit(Op)); 950 951 const Loop *L = Expr->getLoop(); 952 if (0 == Map.count(L)) 953 return SE.getAddRecExpr(Operands, L, Expr->getNoWrapFlags()); 954 955 return SCEVAddRecExpr::evaluateAtIteration(Operands, Map[L], SE); 956 } 957 958 private: 959 LoopToScevMapT ⤅ 960 }; 961 962 } // end namespace llvm 963 964 #endif // LLVM_ANALYSIS_SCALAREVOLUTIONEXPRESSIONS_H 965