1 //===- llvm/Analysis/ScalarEvolutionExpressions.h - SCEV Exprs --*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file defines the classes used to represent and build scalar expressions. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #ifndef LLVM_ANALYSIS_SCALAREVOLUTIONEXPRESSIONS_H 14 #define LLVM_ANALYSIS_SCALAREVOLUTIONEXPRESSIONS_H 15 16 #include "llvm/ADT/DenseMap.h" 17 #include "llvm/ADT/SmallPtrSet.h" 18 #include "llvm/ADT/SmallVector.h" 19 #include "llvm/ADT/iterator_range.h" 20 #include "llvm/Analysis/ScalarEvolution.h" 21 #include "llvm/IR/Constants.h" 22 #include "llvm/IR/ValueHandle.h" 23 #include "llvm/Support/Casting.h" 24 #include "llvm/Support/ErrorHandling.h" 25 #include <cassert> 26 #include <cstddef> 27 28 namespace llvm { 29 30 class APInt; 31 class Constant; 32 class ConstantInt; 33 class ConstantRange; 34 class Loop; 35 class Type; 36 class Value; 37 38 enum SCEVTypes : unsigned short { 39 // These should be ordered in terms of increasing complexity to make the 40 // folders simpler. 41 scConstant, 42 scVScale, 43 scTruncate, 44 scZeroExtend, 45 scSignExtend, 46 scAddExpr, 47 scMulExpr, 48 scUDivExpr, 49 scAddRecExpr, 50 scUMaxExpr, 51 scSMaxExpr, 52 scUMinExpr, 53 scSMinExpr, 54 scSequentialUMinExpr, 55 scPtrToInt, 56 scUnknown, 57 scCouldNotCompute 58 }; 59 60 /// This class represents a constant integer value. 61 class SCEVConstant : public SCEV { 62 friend class ScalarEvolution; 63 64 ConstantInt *V; 65 66 SCEVConstant(const FoldingSetNodeIDRef ID, ConstantInt *v) 67 : SCEV(ID, scConstant, 1), V(v) {} 68 69 public: 70 ConstantInt *getValue() const { return V; } 71 const APInt &getAPInt() const { return getValue()->getValue(); } 72 73 Type *getType() const { return V->getType(); } 74 75 /// Methods for support type inquiry through isa, cast, and dyn_cast: 76 static bool classof(const SCEV *S) { return S->getSCEVType() == scConstant; } 77 }; 78 79 /// This class represents the value of vscale, as used when defining the length 80 /// of a scalable vector or returned by the llvm.vscale() intrinsic. 81 class SCEVVScale : public SCEV { 82 friend class ScalarEvolution; 83 84 SCEVVScale(const FoldingSetNodeIDRef ID, Type *ty) 85 : SCEV(ID, scVScale, 0), Ty(ty) {} 86 87 Type *Ty; 88 89 public: 90 Type *getType() const { return Ty; } 91 92 /// Methods for support type inquiry through isa, cast, and dyn_cast: 93 static bool classof(const SCEV *S) { return S->getSCEVType() == scVScale; } 94 }; 95 96 inline unsigned short computeExpressionSize(ArrayRef<const SCEV *> Args) { 97 APInt Size(16, 1); 98 for (const auto *Arg : Args) 99 Size = Size.uadd_sat(APInt(16, Arg->getExpressionSize())); 100 return (unsigned short)Size.getZExtValue(); 101 } 102 103 /// This is the base class for unary cast operator classes. 104 class SCEVCastExpr : public SCEV { 105 protected: 106 const SCEV *Op; 107 Type *Ty; 108 109 SCEVCastExpr(const FoldingSetNodeIDRef ID, SCEVTypes SCEVTy, const SCEV *op, 110 Type *ty); 111 112 public: 113 const SCEV *getOperand() const { return Op; } 114 const SCEV *getOperand(unsigned i) const { 115 assert(i == 0 && "Operand index out of range!"); 116 return Op; 117 } 118 ArrayRef<const SCEV *> operands() const { return Op; } 119 size_t getNumOperands() const { return 1; } 120 Type *getType() const { return Ty; } 121 122 /// Methods for support type inquiry through isa, cast, and dyn_cast: 123 static bool classof(const SCEV *S) { 124 return S->getSCEVType() == scPtrToInt || S->getSCEVType() == scTruncate || 125 S->getSCEVType() == scZeroExtend || S->getSCEVType() == scSignExtend; 126 } 127 }; 128 129 /// This class represents a cast from a pointer to a pointer-sized integer 130 /// value. 131 class SCEVPtrToIntExpr : public SCEVCastExpr { 132 friend class ScalarEvolution; 133 134 SCEVPtrToIntExpr(const FoldingSetNodeIDRef ID, const SCEV *Op, Type *ITy); 135 136 public: 137 /// Methods for support type inquiry through isa, cast, and dyn_cast: 138 static bool classof(const SCEV *S) { return S->getSCEVType() == scPtrToInt; } 139 }; 140 141 /// This is the base class for unary integral cast operator classes. 142 class SCEVIntegralCastExpr : public SCEVCastExpr { 143 protected: 144 SCEVIntegralCastExpr(const FoldingSetNodeIDRef ID, SCEVTypes SCEVTy, 145 const SCEV *op, Type *ty); 146 147 public: 148 /// Methods for support type inquiry through isa, cast, and dyn_cast: 149 static bool classof(const SCEV *S) { 150 return S->getSCEVType() == scTruncate || S->getSCEVType() == scZeroExtend || 151 S->getSCEVType() == scSignExtend; 152 } 153 }; 154 155 /// This class represents a truncation of an integer value to a 156 /// smaller integer value. 157 class SCEVTruncateExpr : public SCEVIntegralCastExpr { 158 friend class ScalarEvolution; 159 160 SCEVTruncateExpr(const FoldingSetNodeIDRef ID, const SCEV *op, Type *ty); 161 162 public: 163 /// Methods for support type inquiry through isa, cast, and dyn_cast: 164 static bool classof(const SCEV *S) { return S->getSCEVType() == scTruncate; } 165 }; 166 167 /// This class represents a zero extension of a small integer value 168 /// to a larger integer value. 169 class SCEVZeroExtendExpr : public SCEVIntegralCastExpr { 170 friend class ScalarEvolution; 171 172 SCEVZeroExtendExpr(const FoldingSetNodeIDRef ID, const SCEV *op, Type *ty); 173 174 public: 175 /// Methods for support type inquiry through isa, cast, and dyn_cast: 176 static bool classof(const SCEV *S) { 177 return S->getSCEVType() == scZeroExtend; 178 } 179 }; 180 181 /// This class represents a sign extension of a small integer value 182 /// to a larger integer value. 183 class SCEVSignExtendExpr : public SCEVIntegralCastExpr { 184 friend class ScalarEvolution; 185 186 SCEVSignExtendExpr(const FoldingSetNodeIDRef ID, const SCEV *op, Type *ty); 187 188 public: 189 /// Methods for support type inquiry through isa, cast, and dyn_cast: 190 static bool classof(const SCEV *S) { 191 return S->getSCEVType() == scSignExtend; 192 } 193 }; 194 195 /// This node is a base class providing common functionality for 196 /// n'ary operators. 197 class SCEVNAryExpr : public SCEV { 198 protected: 199 // Since SCEVs are immutable, ScalarEvolution allocates operand 200 // arrays with its SCEVAllocator, so this class just needs a simple 201 // pointer rather than a more elaborate vector-like data structure. 202 // This also avoids the need for a non-trivial destructor. 203 const SCEV *const *Operands; 204 size_t NumOperands; 205 206 SCEVNAryExpr(const FoldingSetNodeIDRef ID, enum SCEVTypes T, 207 const SCEV *const *O, size_t N) 208 : SCEV(ID, T, computeExpressionSize(ArrayRef(O, N))), Operands(O), 209 NumOperands(N) {} 210 211 public: 212 size_t getNumOperands() const { return NumOperands; } 213 214 const SCEV *getOperand(unsigned i) const { 215 assert(i < NumOperands && "Operand index out of range!"); 216 return Operands[i]; 217 } 218 219 ArrayRef<const SCEV *> operands() const { 220 return ArrayRef(Operands, NumOperands); 221 } 222 223 NoWrapFlags getNoWrapFlags(NoWrapFlags Mask = NoWrapMask) const { 224 return (NoWrapFlags)(SubclassData & Mask); 225 } 226 227 bool hasNoUnsignedWrap() const { 228 return getNoWrapFlags(FlagNUW) != FlagAnyWrap; 229 } 230 231 bool hasNoSignedWrap() const { 232 return getNoWrapFlags(FlagNSW) != FlagAnyWrap; 233 } 234 235 bool hasNoSelfWrap() const { return getNoWrapFlags(FlagNW) != FlagAnyWrap; } 236 237 /// Methods for support type inquiry through isa, cast, and dyn_cast: 238 static bool classof(const SCEV *S) { 239 return S->getSCEVType() == scAddExpr || S->getSCEVType() == scMulExpr || 240 S->getSCEVType() == scSMaxExpr || S->getSCEVType() == scUMaxExpr || 241 S->getSCEVType() == scSMinExpr || S->getSCEVType() == scUMinExpr || 242 S->getSCEVType() == scSequentialUMinExpr || 243 S->getSCEVType() == scAddRecExpr; 244 } 245 }; 246 247 /// This node is the base class for n'ary commutative operators. 248 class SCEVCommutativeExpr : public SCEVNAryExpr { 249 protected: 250 SCEVCommutativeExpr(const FoldingSetNodeIDRef ID, enum SCEVTypes T, 251 const SCEV *const *O, size_t N) 252 : SCEVNAryExpr(ID, T, O, N) {} 253 254 public: 255 /// Methods for support type inquiry through isa, cast, and dyn_cast: 256 static bool classof(const SCEV *S) { 257 return S->getSCEVType() == scAddExpr || S->getSCEVType() == scMulExpr || 258 S->getSCEVType() == scSMaxExpr || S->getSCEVType() == scUMaxExpr || 259 S->getSCEVType() == scSMinExpr || S->getSCEVType() == scUMinExpr; 260 } 261 262 /// Set flags for a non-recurrence without clearing previously set flags. 263 void setNoWrapFlags(NoWrapFlags Flags) { SubclassData |= Flags; } 264 }; 265 266 /// This node represents an addition of some number of SCEVs. 267 class SCEVAddExpr : public SCEVCommutativeExpr { 268 friend class ScalarEvolution; 269 270 Type *Ty; 271 272 SCEVAddExpr(const FoldingSetNodeIDRef ID, const SCEV *const *O, size_t N) 273 : SCEVCommutativeExpr(ID, scAddExpr, O, N) { 274 auto *FirstPointerTypedOp = find_if(operands(), [](const SCEV *Op) { 275 return Op->getType()->isPointerTy(); 276 }); 277 if (FirstPointerTypedOp != operands().end()) 278 Ty = (*FirstPointerTypedOp)->getType(); 279 else 280 Ty = getOperand(0)->getType(); 281 } 282 283 public: 284 Type *getType() const { return Ty; } 285 286 /// Methods for support type inquiry through isa, cast, and dyn_cast: 287 static bool classof(const SCEV *S) { return S->getSCEVType() == scAddExpr; } 288 }; 289 290 /// This node represents multiplication of some number of SCEVs. 291 class SCEVMulExpr : public SCEVCommutativeExpr { 292 friend class ScalarEvolution; 293 294 SCEVMulExpr(const FoldingSetNodeIDRef ID, const SCEV *const *O, size_t N) 295 : SCEVCommutativeExpr(ID, scMulExpr, O, N) {} 296 297 public: 298 Type *getType() const { return getOperand(0)->getType(); } 299 300 /// Methods for support type inquiry through isa, cast, and dyn_cast: 301 static bool classof(const SCEV *S) { return S->getSCEVType() == scMulExpr; } 302 }; 303 304 /// This class represents a binary unsigned division operation. 305 class SCEVUDivExpr : public SCEV { 306 friend class ScalarEvolution; 307 308 std::array<const SCEV *, 2> Operands; 309 310 SCEVUDivExpr(const FoldingSetNodeIDRef ID, const SCEV *lhs, const SCEV *rhs) 311 : SCEV(ID, scUDivExpr, computeExpressionSize({lhs, rhs})) { 312 Operands[0] = lhs; 313 Operands[1] = rhs; 314 } 315 316 public: 317 const SCEV *getLHS() const { return Operands[0]; } 318 const SCEV *getRHS() const { return Operands[1]; } 319 size_t getNumOperands() const { return 2; } 320 const SCEV *getOperand(unsigned i) const { 321 assert((i == 0 || i == 1) && "Operand index out of range!"); 322 return i == 0 ? getLHS() : getRHS(); 323 } 324 325 ArrayRef<const SCEV *> operands() const { return Operands; } 326 327 Type *getType() const { 328 // In most cases the types of LHS and RHS will be the same, but in some 329 // crazy cases one or the other may be a pointer. ScalarEvolution doesn't 330 // depend on the type for correctness, but handling types carefully can 331 // avoid extra casts in the SCEVExpander. The LHS is more likely to be 332 // a pointer type than the RHS, so use the RHS' type here. 333 return getRHS()->getType(); 334 } 335 336 /// Methods for support type inquiry through isa, cast, and dyn_cast: 337 static bool classof(const SCEV *S) { return S->getSCEVType() == scUDivExpr; } 338 }; 339 340 /// This node represents a polynomial recurrence on the trip count 341 /// of the specified loop. This is the primary focus of the 342 /// ScalarEvolution framework; all the other SCEV subclasses are 343 /// mostly just supporting infrastructure to allow SCEVAddRecExpr 344 /// expressions to be created and analyzed. 345 /// 346 /// All operands of an AddRec are required to be loop invariant. 347 /// 348 class SCEVAddRecExpr : public SCEVNAryExpr { 349 friend class ScalarEvolution; 350 351 const Loop *L; 352 353 SCEVAddRecExpr(const FoldingSetNodeIDRef ID, const SCEV *const *O, size_t N, 354 const Loop *l) 355 : SCEVNAryExpr(ID, scAddRecExpr, O, N), L(l) {} 356 357 public: 358 Type *getType() const { return getStart()->getType(); } 359 const SCEV *getStart() const { return Operands[0]; } 360 const Loop *getLoop() const { return L; } 361 362 /// Constructs and returns the recurrence indicating how much this 363 /// expression steps by. If this is a polynomial of degree N, it 364 /// returns a chrec of degree N-1. We cannot determine whether 365 /// the step recurrence has self-wraparound. 366 const SCEV *getStepRecurrence(ScalarEvolution &SE) const { 367 if (isAffine()) 368 return getOperand(1); 369 return SE.getAddRecExpr( 370 SmallVector<const SCEV *, 3>(operands().drop_front()), getLoop(), 371 FlagAnyWrap); 372 } 373 374 /// Return true if this represents an expression A + B*x where A 375 /// and B are loop invariant values. 376 bool isAffine() const { 377 // We know that the start value is invariant. This expression is thus 378 // affine iff the step is also invariant. 379 return getNumOperands() == 2; 380 } 381 382 /// Return true if this represents an expression A + B*x + C*x^2 383 /// where A, B and C are loop invariant values. This corresponds 384 /// to an addrec of the form {L,+,M,+,N} 385 bool isQuadratic() const { return getNumOperands() == 3; } 386 387 /// Set flags for a recurrence without clearing any previously set flags. 388 /// For AddRec, either NUW or NSW implies NW. Keep track of this fact here 389 /// to make it easier to propagate flags. 390 void setNoWrapFlags(NoWrapFlags Flags) { 391 if (Flags & (FlagNUW | FlagNSW)) 392 Flags = ScalarEvolution::setFlags(Flags, FlagNW); 393 SubclassData |= Flags; 394 } 395 396 /// Return the value of this chain of recurrences at the specified 397 /// iteration number. 398 const SCEV *evaluateAtIteration(const SCEV *It, ScalarEvolution &SE) const; 399 400 /// Return the value of this chain of recurrences at the specified iteration 401 /// number. Takes an explicit list of operands to represent an AddRec. 402 static const SCEV *evaluateAtIteration(ArrayRef<const SCEV *> Operands, 403 const SCEV *It, ScalarEvolution &SE); 404 405 /// Return the number of iterations of this loop that produce 406 /// values in the specified constant range. Another way of 407 /// looking at this is that it returns the first iteration number 408 /// where the value is not in the condition, thus computing the 409 /// exit count. If the iteration count can't be computed, an 410 /// instance of SCEVCouldNotCompute is returned. 411 const SCEV *getNumIterationsInRange(const ConstantRange &Range, 412 ScalarEvolution &SE) const; 413 414 /// Return an expression representing the value of this expression 415 /// one iteration of the loop ahead. 416 const SCEVAddRecExpr *getPostIncExpr(ScalarEvolution &SE) const; 417 418 /// Methods for support type inquiry through isa, cast, and dyn_cast: 419 static bool classof(const SCEV *S) { 420 return S->getSCEVType() == scAddRecExpr; 421 } 422 }; 423 424 /// This node is the base class min/max selections. 425 class SCEVMinMaxExpr : public SCEVCommutativeExpr { 426 friend class ScalarEvolution; 427 428 static bool isMinMaxType(enum SCEVTypes T) { 429 return T == scSMaxExpr || T == scUMaxExpr || T == scSMinExpr || 430 T == scUMinExpr; 431 } 432 433 protected: 434 /// Note: Constructing subclasses via this constructor is allowed 435 SCEVMinMaxExpr(const FoldingSetNodeIDRef ID, enum SCEVTypes T, 436 const SCEV *const *O, size_t N) 437 : SCEVCommutativeExpr(ID, T, O, N) { 438 assert(isMinMaxType(T)); 439 // Min and max never overflow 440 setNoWrapFlags((NoWrapFlags)(FlagNUW | FlagNSW)); 441 } 442 443 public: 444 Type *getType() const { return getOperand(0)->getType(); } 445 446 static bool classof(const SCEV *S) { return isMinMaxType(S->getSCEVType()); } 447 448 static enum SCEVTypes negate(enum SCEVTypes T) { 449 switch (T) { 450 case scSMaxExpr: 451 return scSMinExpr; 452 case scSMinExpr: 453 return scSMaxExpr; 454 case scUMaxExpr: 455 return scUMinExpr; 456 case scUMinExpr: 457 return scUMaxExpr; 458 default: 459 llvm_unreachable("Not a min or max SCEV type!"); 460 } 461 } 462 }; 463 464 /// This class represents a signed maximum selection. 465 class SCEVSMaxExpr : public SCEVMinMaxExpr { 466 friend class ScalarEvolution; 467 468 SCEVSMaxExpr(const FoldingSetNodeIDRef ID, const SCEV *const *O, size_t N) 469 : SCEVMinMaxExpr(ID, scSMaxExpr, O, N) {} 470 471 public: 472 /// Methods for support type inquiry through isa, cast, and dyn_cast: 473 static bool classof(const SCEV *S) { return S->getSCEVType() == scSMaxExpr; } 474 }; 475 476 /// This class represents an unsigned maximum selection. 477 class SCEVUMaxExpr : public SCEVMinMaxExpr { 478 friend class ScalarEvolution; 479 480 SCEVUMaxExpr(const FoldingSetNodeIDRef ID, const SCEV *const *O, size_t N) 481 : SCEVMinMaxExpr(ID, scUMaxExpr, O, N) {} 482 483 public: 484 /// Methods for support type inquiry through isa, cast, and dyn_cast: 485 static bool classof(const SCEV *S) { return S->getSCEVType() == scUMaxExpr; } 486 }; 487 488 /// This class represents a signed minimum selection. 489 class SCEVSMinExpr : public SCEVMinMaxExpr { 490 friend class ScalarEvolution; 491 492 SCEVSMinExpr(const FoldingSetNodeIDRef ID, const SCEV *const *O, size_t N) 493 : SCEVMinMaxExpr(ID, scSMinExpr, O, N) {} 494 495 public: 496 /// Methods for support type inquiry through isa, cast, and dyn_cast: 497 static bool classof(const SCEV *S) { return S->getSCEVType() == scSMinExpr; } 498 }; 499 500 /// This class represents an unsigned minimum selection. 501 class SCEVUMinExpr : public SCEVMinMaxExpr { 502 friend class ScalarEvolution; 503 504 SCEVUMinExpr(const FoldingSetNodeIDRef ID, const SCEV *const *O, size_t N) 505 : SCEVMinMaxExpr(ID, scUMinExpr, O, N) {} 506 507 public: 508 /// Methods for support type inquiry through isa, cast, and dyn_cast: 509 static bool classof(const SCEV *S) { return S->getSCEVType() == scUMinExpr; } 510 }; 511 512 /// This node is the base class for sequential/in-order min/max selections. 513 /// Note that their fundamental difference from SCEVMinMaxExpr's is that they 514 /// are early-returning upon reaching saturation point. 515 /// I.e. given `0 umin_seq poison`, the result will be `0`, 516 /// while the result of `0 umin poison` is `poison`. 517 class SCEVSequentialMinMaxExpr : public SCEVNAryExpr { 518 friend class ScalarEvolution; 519 520 static bool isSequentialMinMaxType(enum SCEVTypes T) { 521 return T == scSequentialUMinExpr; 522 } 523 524 /// Set flags for a non-recurrence without clearing previously set flags. 525 void setNoWrapFlags(NoWrapFlags Flags) { SubclassData |= Flags; } 526 527 protected: 528 /// Note: Constructing subclasses via this constructor is allowed 529 SCEVSequentialMinMaxExpr(const FoldingSetNodeIDRef ID, enum SCEVTypes T, 530 const SCEV *const *O, size_t N) 531 : SCEVNAryExpr(ID, T, O, N) { 532 assert(isSequentialMinMaxType(T)); 533 // Min and max never overflow 534 setNoWrapFlags((NoWrapFlags)(FlagNUW | FlagNSW)); 535 } 536 537 public: 538 Type *getType() const { return getOperand(0)->getType(); } 539 540 static SCEVTypes getEquivalentNonSequentialSCEVType(SCEVTypes Ty) { 541 assert(isSequentialMinMaxType(Ty)); 542 switch (Ty) { 543 case scSequentialUMinExpr: 544 return scUMinExpr; 545 default: 546 llvm_unreachable("Not a sequential min/max type."); 547 } 548 } 549 550 SCEVTypes getEquivalentNonSequentialSCEVType() const { 551 return getEquivalentNonSequentialSCEVType(getSCEVType()); 552 } 553 554 static bool classof(const SCEV *S) { 555 return isSequentialMinMaxType(S->getSCEVType()); 556 } 557 }; 558 559 /// This class represents a sequential/in-order unsigned minimum selection. 560 class SCEVSequentialUMinExpr : public SCEVSequentialMinMaxExpr { 561 friend class ScalarEvolution; 562 563 SCEVSequentialUMinExpr(const FoldingSetNodeIDRef ID, const SCEV *const *O, 564 size_t N) 565 : SCEVSequentialMinMaxExpr(ID, scSequentialUMinExpr, O, N) {} 566 567 public: 568 /// Methods for support type inquiry through isa, cast, and dyn_cast: 569 static bool classof(const SCEV *S) { 570 return S->getSCEVType() == scSequentialUMinExpr; 571 } 572 }; 573 574 /// This means that we are dealing with an entirely unknown SCEV 575 /// value, and only represent it as its LLVM Value. This is the 576 /// "bottom" value for the analysis. 577 class SCEVUnknown final : public SCEV, private CallbackVH { 578 friend class ScalarEvolution; 579 580 /// The parent ScalarEvolution value. This is used to update the 581 /// parent's maps when the value associated with a SCEVUnknown is 582 /// deleted or RAUW'd. 583 ScalarEvolution *SE; 584 585 /// The next pointer in the linked list of all SCEVUnknown 586 /// instances owned by a ScalarEvolution. 587 SCEVUnknown *Next; 588 589 SCEVUnknown(const FoldingSetNodeIDRef ID, Value *V, ScalarEvolution *se, 590 SCEVUnknown *next) 591 : SCEV(ID, scUnknown, 1), CallbackVH(V), SE(se), Next(next) {} 592 593 // Implement CallbackVH. 594 void deleted() override; 595 void allUsesReplacedWith(Value *New) override; 596 597 public: 598 Value *getValue() const { return getValPtr(); } 599 600 Type *getType() const { return getValPtr()->getType(); } 601 602 /// Methods for support type inquiry through isa, cast, and dyn_cast: 603 static bool classof(const SCEV *S) { return S->getSCEVType() == scUnknown; } 604 }; 605 606 /// This class defines a simple visitor class that may be used for 607 /// various SCEV analysis purposes. 608 template <typename SC, typename RetVal = void> struct SCEVVisitor { 609 RetVal visit(const SCEV *S) { 610 switch (S->getSCEVType()) { 611 case scConstant: 612 return ((SC *)this)->visitConstant((const SCEVConstant *)S); 613 case scVScale: 614 return ((SC *)this)->visitVScale((const SCEVVScale *)S); 615 case scPtrToInt: 616 return ((SC *)this)->visitPtrToIntExpr((const SCEVPtrToIntExpr *)S); 617 case scTruncate: 618 return ((SC *)this)->visitTruncateExpr((const SCEVTruncateExpr *)S); 619 case scZeroExtend: 620 return ((SC *)this)->visitZeroExtendExpr((const SCEVZeroExtendExpr *)S); 621 case scSignExtend: 622 return ((SC *)this)->visitSignExtendExpr((const SCEVSignExtendExpr *)S); 623 case scAddExpr: 624 return ((SC *)this)->visitAddExpr((const SCEVAddExpr *)S); 625 case scMulExpr: 626 return ((SC *)this)->visitMulExpr((const SCEVMulExpr *)S); 627 case scUDivExpr: 628 return ((SC *)this)->visitUDivExpr((const SCEVUDivExpr *)S); 629 case scAddRecExpr: 630 return ((SC *)this)->visitAddRecExpr((const SCEVAddRecExpr *)S); 631 case scSMaxExpr: 632 return ((SC *)this)->visitSMaxExpr((const SCEVSMaxExpr *)S); 633 case scUMaxExpr: 634 return ((SC *)this)->visitUMaxExpr((const SCEVUMaxExpr *)S); 635 case scSMinExpr: 636 return ((SC *)this)->visitSMinExpr((const SCEVSMinExpr *)S); 637 case scUMinExpr: 638 return ((SC *)this)->visitUMinExpr((const SCEVUMinExpr *)S); 639 case scSequentialUMinExpr: 640 return ((SC *)this) 641 ->visitSequentialUMinExpr((const SCEVSequentialUMinExpr *)S); 642 case scUnknown: 643 return ((SC *)this)->visitUnknown((const SCEVUnknown *)S); 644 case scCouldNotCompute: 645 return ((SC *)this)->visitCouldNotCompute((const SCEVCouldNotCompute *)S); 646 } 647 llvm_unreachable("Unknown SCEV kind!"); 648 } 649 650 RetVal visitCouldNotCompute(const SCEVCouldNotCompute *S) { 651 llvm_unreachable("Invalid use of SCEVCouldNotCompute!"); 652 } 653 }; 654 655 /// Visit all nodes in the expression tree using worklist traversal. 656 /// 657 /// Visitor implements: 658 /// // return true to follow this node. 659 /// bool follow(const SCEV *S); 660 /// // return true to terminate the search. 661 /// bool isDone(); 662 template <typename SV> class SCEVTraversal { 663 SV &Visitor; 664 SmallVector<const SCEV *, 8> Worklist; 665 SmallPtrSet<const SCEV *, 8> Visited; 666 667 void push(const SCEV *S) { 668 if (Visited.insert(S).second && Visitor.follow(S)) 669 Worklist.push_back(S); 670 } 671 672 public: 673 SCEVTraversal(SV &V) : Visitor(V) {} 674 675 void visitAll(const SCEV *Root) { 676 push(Root); 677 while (!Worklist.empty() && !Visitor.isDone()) { 678 const SCEV *S = Worklist.pop_back_val(); 679 680 switch (S->getSCEVType()) { 681 case scConstant: 682 case scVScale: 683 case scUnknown: 684 continue; 685 case scPtrToInt: 686 case scTruncate: 687 case scZeroExtend: 688 case scSignExtend: 689 case scAddExpr: 690 case scMulExpr: 691 case scUDivExpr: 692 case scSMaxExpr: 693 case scUMaxExpr: 694 case scSMinExpr: 695 case scUMinExpr: 696 case scSequentialUMinExpr: 697 case scAddRecExpr: 698 for (const auto *Op : S->operands()) { 699 push(Op); 700 if (Visitor.isDone()) 701 break; 702 } 703 continue; 704 case scCouldNotCompute: 705 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!"); 706 } 707 llvm_unreachable("Unknown SCEV kind!"); 708 } 709 } 710 }; 711 712 /// Use SCEVTraversal to visit all nodes in the given expression tree. 713 template <typename SV> void visitAll(const SCEV *Root, SV &Visitor) { 714 SCEVTraversal<SV> T(Visitor); 715 T.visitAll(Root); 716 } 717 718 /// Return true if any node in \p Root satisfies the predicate \p Pred. 719 template <typename PredTy> 720 bool SCEVExprContains(const SCEV *Root, PredTy Pred) { 721 struct FindClosure { 722 bool Found = false; 723 PredTy Pred; 724 725 FindClosure(PredTy Pred) : Pred(Pred) {} 726 727 bool follow(const SCEV *S) { 728 if (!Pred(S)) 729 return true; 730 731 Found = true; 732 return false; 733 } 734 735 bool isDone() const { return Found; } 736 }; 737 738 FindClosure FC(Pred); 739 visitAll(Root, FC); 740 return FC.Found; 741 } 742 743 /// This visitor recursively visits a SCEV expression and re-writes it. 744 /// The result from each visit is cached, so it will return the same 745 /// SCEV for the same input. 746 template <typename SC> 747 class SCEVRewriteVisitor : public SCEVVisitor<SC, const SCEV *> { 748 protected: 749 ScalarEvolution &SE; 750 // Memoize the result of each visit so that we only compute once for 751 // the same input SCEV. This is to avoid redundant computations when 752 // a SCEV is referenced by multiple SCEVs. Without memoization, this 753 // visit algorithm would have exponential time complexity in the worst 754 // case, causing the compiler to hang on certain tests. 755 SmallDenseMap<const SCEV *, const SCEV *> RewriteResults; 756 757 public: 758 SCEVRewriteVisitor(ScalarEvolution &SE) : SE(SE) {} 759 760 const SCEV *visit(const SCEV *S) { 761 auto It = RewriteResults.find(S); 762 if (It != RewriteResults.end()) 763 return It->second; 764 auto *Visited = SCEVVisitor<SC, const SCEV *>::visit(S); 765 auto Result = RewriteResults.try_emplace(S, Visited); 766 assert(Result.second && "Should insert a new entry"); 767 return Result.first->second; 768 } 769 770 const SCEV *visitConstant(const SCEVConstant *Constant) { return Constant; } 771 772 const SCEV *visitVScale(const SCEVVScale *VScale) { return VScale; } 773 774 const SCEV *visitPtrToIntExpr(const SCEVPtrToIntExpr *Expr) { 775 const SCEV *Operand = ((SC *)this)->visit(Expr->getOperand()); 776 return Operand == Expr->getOperand() 777 ? Expr 778 : SE.getPtrToIntExpr(Operand, Expr->getType()); 779 } 780 781 const SCEV *visitTruncateExpr(const SCEVTruncateExpr *Expr) { 782 const SCEV *Operand = ((SC *)this)->visit(Expr->getOperand()); 783 return Operand == Expr->getOperand() 784 ? Expr 785 : SE.getTruncateExpr(Operand, Expr->getType()); 786 } 787 788 const SCEV *visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) { 789 const SCEV *Operand = ((SC *)this)->visit(Expr->getOperand()); 790 return Operand == Expr->getOperand() 791 ? Expr 792 : SE.getZeroExtendExpr(Operand, Expr->getType()); 793 } 794 795 const SCEV *visitSignExtendExpr(const SCEVSignExtendExpr *Expr) { 796 const SCEV *Operand = ((SC *)this)->visit(Expr->getOperand()); 797 return Operand == Expr->getOperand() 798 ? Expr 799 : SE.getSignExtendExpr(Operand, Expr->getType()); 800 } 801 802 const SCEV *visitAddExpr(const SCEVAddExpr *Expr) { 803 SmallVector<const SCEV *, 2> Operands; 804 bool Changed = false; 805 for (const auto *Op : Expr->operands()) { 806 Operands.push_back(((SC *)this)->visit(Op)); 807 Changed |= Op != Operands.back(); 808 } 809 return !Changed ? Expr : SE.getAddExpr(Operands); 810 } 811 812 const SCEV *visitMulExpr(const SCEVMulExpr *Expr) { 813 SmallVector<const SCEV *, 2> Operands; 814 bool Changed = false; 815 for (const auto *Op : Expr->operands()) { 816 Operands.push_back(((SC *)this)->visit(Op)); 817 Changed |= Op != Operands.back(); 818 } 819 return !Changed ? Expr : SE.getMulExpr(Operands); 820 } 821 822 const SCEV *visitUDivExpr(const SCEVUDivExpr *Expr) { 823 auto *LHS = ((SC *)this)->visit(Expr->getLHS()); 824 auto *RHS = ((SC *)this)->visit(Expr->getRHS()); 825 bool Changed = LHS != Expr->getLHS() || RHS != Expr->getRHS(); 826 return !Changed ? Expr : SE.getUDivExpr(LHS, RHS); 827 } 828 829 const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) { 830 SmallVector<const SCEV *, 2> Operands; 831 bool Changed = false; 832 for (const auto *Op : Expr->operands()) { 833 Operands.push_back(((SC *)this)->visit(Op)); 834 Changed |= Op != Operands.back(); 835 } 836 return !Changed ? Expr 837 : SE.getAddRecExpr(Operands, Expr->getLoop(), 838 Expr->getNoWrapFlags()); 839 } 840 841 const SCEV *visitSMaxExpr(const SCEVSMaxExpr *Expr) { 842 SmallVector<const SCEV *, 2> Operands; 843 bool Changed = false; 844 for (const auto *Op : Expr->operands()) { 845 Operands.push_back(((SC *)this)->visit(Op)); 846 Changed |= Op != Operands.back(); 847 } 848 return !Changed ? Expr : SE.getSMaxExpr(Operands); 849 } 850 851 const SCEV *visitUMaxExpr(const SCEVUMaxExpr *Expr) { 852 SmallVector<const SCEV *, 2> Operands; 853 bool Changed = false; 854 for (const auto *Op : Expr->operands()) { 855 Operands.push_back(((SC *)this)->visit(Op)); 856 Changed |= Op != Operands.back(); 857 } 858 return !Changed ? Expr : SE.getUMaxExpr(Operands); 859 } 860 861 const SCEV *visitSMinExpr(const SCEVSMinExpr *Expr) { 862 SmallVector<const SCEV *, 2> Operands; 863 bool Changed = false; 864 for (const auto *Op : Expr->operands()) { 865 Operands.push_back(((SC *)this)->visit(Op)); 866 Changed |= Op != Operands.back(); 867 } 868 return !Changed ? Expr : SE.getSMinExpr(Operands); 869 } 870 871 const SCEV *visitUMinExpr(const SCEVUMinExpr *Expr) { 872 SmallVector<const SCEV *, 2> Operands; 873 bool Changed = false; 874 for (const auto *Op : Expr->operands()) { 875 Operands.push_back(((SC *)this)->visit(Op)); 876 Changed |= Op != Operands.back(); 877 } 878 return !Changed ? Expr : SE.getUMinExpr(Operands); 879 } 880 881 const SCEV *visitSequentialUMinExpr(const SCEVSequentialUMinExpr *Expr) { 882 SmallVector<const SCEV *, 2> Operands; 883 bool Changed = false; 884 for (const auto *Op : Expr->operands()) { 885 Operands.push_back(((SC *)this)->visit(Op)); 886 Changed |= Op != Operands.back(); 887 } 888 return !Changed ? Expr : SE.getUMinExpr(Operands, /*Sequential=*/true); 889 } 890 891 const SCEV *visitUnknown(const SCEVUnknown *Expr) { return Expr; } 892 893 const SCEV *visitCouldNotCompute(const SCEVCouldNotCompute *Expr) { 894 return Expr; 895 } 896 }; 897 898 using ValueToValueMap = DenseMap<const Value *, Value *>; 899 using ValueToSCEVMapTy = DenseMap<const Value *, const SCEV *>; 900 901 /// The SCEVParameterRewriter takes a scalar evolution expression and updates 902 /// the SCEVUnknown components following the Map (Value -> SCEV). 903 class SCEVParameterRewriter : public SCEVRewriteVisitor<SCEVParameterRewriter> { 904 public: 905 static const SCEV *rewrite(const SCEV *Scev, ScalarEvolution &SE, 906 ValueToSCEVMapTy &Map) { 907 SCEVParameterRewriter Rewriter(SE, Map); 908 return Rewriter.visit(Scev); 909 } 910 911 SCEVParameterRewriter(ScalarEvolution &SE, ValueToSCEVMapTy &M) 912 : SCEVRewriteVisitor(SE), Map(M) {} 913 914 const SCEV *visitUnknown(const SCEVUnknown *Expr) { 915 auto I = Map.find(Expr->getValue()); 916 if (I == Map.end()) 917 return Expr; 918 return I->second; 919 } 920 921 private: 922 ValueToSCEVMapTy ⤅ 923 }; 924 925 using LoopToScevMapT = DenseMap<const Loop *, const SCEV *>; 926 927 /// The SCEVLoopAddRecRewriter takes a scalar evolution expression and applies 928 /// the Map (Loop -> SCEV) to all AddRecExprs. 929 class SCEVLoopAddRecRewriter 930 : public SCEVRewriteVisitor<SCEVLoopAddRecRewriter> { 931 public: 932 SCEVLoopAddRecRewriter(ScalarEvolution &SE, LoopToScevMapT &M) 933 : SCEVRewriteVisitor(SE), Map(M) {} 934 935 static const SCEV *rewrite(const SCEV *Scev, LoopToScevMapT &Map, 936 ScalarEvolution &SE) { 937 SCEVLoopAddRecRewriter Rewriter(SE, Map); 938 return Rewriter.visit(Scev); 939 } 940 941 const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) { 942 SmallVector<const SCEV *, 2> Operands; 943 for (const SCEV *Op : Expr->operands()) 944 Operands.push_back(visit(Op)); 945 946 const Loop *L = Expr->getLoop(); 947 if (0 == Map.count(L)) 948 return SE.getAddRecExpr(Operands, L, Expr->getNoWrapFlags()); 949 950 return SCEVAddRecExpr::evaluateAtIteration(Operands, Map[L], SE); 951 } 952 953 private: 954 LoopToScevMapT ⤅ 955 }; 956 957 } // end namespace llvm 958 959 #endif // LLVM_ANALYSIS_SCALAREVOLUTIONEXPRESSIONS_H 960