//===- GenericUniformityImpl.h -----------------------*- C++ -*------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This template implementation resides in a separate file so that it
// does not get injected into every .cpp file that includes the
// generic header.
//
// DO NOT INCLUDE THIS FILE WHEN MERELY USING UNIFORMITYINFO.
//
// This file should only be included by files that implement a
// specialization of the relevant templates. Currently these are:
// - UniformityAnalysis.cpp
//
// Note: The DEBUG_TYPE macro should be defined before using this
// file so that any use of LLVM_DEBUG is associated with the
// including file rather than this file.
//
//===----------------------------------------------------------------------===//
///
/// \file
/// \brief Implementation of uniformity analysis.
///
/// The algorithm is a fixed point iteration that starts with the assumption
/// that all control flow and all values are uniform. Starting from sources of
/// divergence (whose discovery must be implemented by a CFG- or even
/// target-specific derived class), divergence of values is propagated from
/// definition to uses in a straight-forward way. The main complexity lies in
/// the propagation of the impact of divergent control flow on the divergence of
/// values (sync dependencies).
///
//===----------------------------------------------------------------------===//

#ifndef LLVM_ADT_GENERICUNIFORMITYIMPL_H
#define LLVM_ADT_GENERICUNIFORMITYIMPL_H

#include "llvm/ADT/GenericUniformityInfo.h"

#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/SparseBitVector.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/Support/raw_ostream.h"

#include <set>

#define DEBUG_TYPE "uniformity"

using namespace llvm;

namespace llvm {

/// Apply \c std::unique to a range, locating begin/end via ADL.
template <typename Range> auto unique(Range &&R) {
  return std::unique(adl_begin(R), adl_end(R));
}

/// Construct a specially modified post-order traversal of cycles.
///
/// The ModifiedPO is constructed using a virtually modified CFG as follows:
///
/// 1. The successors of pre-entry nodes (predecessors of a cycle
///    entry that are outside the cycle) are replaced by the
///    successors of the successors of the header.
/// 2. Successors of the cycle header are replaced by the exit blocks
///    of the cycle.
///
/// Effectively, we produce a depth-first numbering with the following
/// properties:
///
/// 1. Nodes after a cycle are numbered earlier than the cycle header.
/// 2. The header is numbered earlier than the nodes in the cycle.
/// 3. The numbering of the nodes within the cycle forms an interval
///    starting with the header.
///
/// Effectively, the virtual modification arranges the nodes in a
/// cycle as a DAG with the header as the sole leaf, and successors of
/// the header as the roots. A reverse traversal of this numbering has
/// the following invariant on the unmodified original CFG:
///
///   Each node is visited after all its predecessors, except if that
///   predecessor is the cycle header.
///
template <typename ContextT> class ModifiedPostOrder {
public:
  using BlockT = typename ContextT::BlockT;
  using FunctionT = typename ContextT::FunctionT;
  using DominatorTreeT = typename ContextT::DominatorTreeT;

  using CycleInfoT = GenericCycleInfo<ContextT>;
  using CycleT = typename CycleInfoT::CycleT;
  using const_iterator = typename std::vector<BlockT *>::const_iterator;

  ModifiedPostOrder(const ContextT &C) : Context(C) {}

  bool empty() const { return m_order.empty(); }
  size_t size() const { return m_order.size(); }

  void clear() { m_order.clear(); }
  void compute(const CycleInfoT &CI);

  /// Whether \p BB has already been assigned a position in the order.
  unsigned count(BlockT *BB) const { return POIndex.count(BB); }
  const BlockT *operator[](size_t idx) const { return m_order[idx]; }

  /// Append \p BB as the next block in the traversal and record its index.
  void appendBlock(const BlockT &BB, bool isReducibleCycleHeader = false) {
    POIndex[&BB] = m_order.size();
    m_order.push_back(&BB);
    LLVM_DEBUG(dbgs() << "ModifiedPO(" << POIndex[&BB]
                      << "): " << Context.print(&BB) << "\n");
    if (isReducibleCycleHeader)
      ReducibleCycleHeaders.insert(&BB);
  }

  unsigned getIndex(const BlockT *BB) const {
    assert(POIndex.count(BB));
    return POIndex.lookup(BB);
  }

  bool isReducibleCycleHeader(const BlockT *BB) const {
    return ReducibleCycleHeaders.contains(BB);
  }

private:
  SmallVector<const BlockT *> m_order;
  DenseMap<const BlockT *, unsigned> POIndex;
  SmallPtrSet<const BlockT *, 32> ReducibleCycleHeaders;
  const ContextT &Context;

  void computeCyclePO(const CycleInfoT &CI, const CycleT *Cycle,
                      SmallPtrSetImpl<BlockT *> &Finalized);

  void computeStackPO(SmallVectorImpl<BlockT *> &Stack, const CycleInfoT &CI,
                      const CycleT *Cycle,
                      SmallPtrSetImpl<BlockT *> &Finalized);
};

template <typename> class DivergencePropagator;

/// \class GenericSyncDependenceAnalysis
///
/// \brief Locate join blocks for disjoint paths starting at a
/// divergent branch.
///
/// An analysis per divergent branch that returns the set of basic
/// blocks whose phi nodes become divergent due to divergent control.
/// These are the blocks that are reachable by two disjoint paths from
/// the branch, or cycle exits reachable along a path that is disjoint
/// from a path to the cycle latch.

// --- Above line is not a doxygen comment; intentionally left blank ---
//
// Originally implemented in SyncDependenceAnalysis.cpp for DivergenceAnalysis.
//
// The SyncDependenceAnalysis is used in the UniformityAnalysis to model
// control-induced divergence in phi nodes.
//
// -- Reference --
// The algorithm is an extension of Section 5 of
//
//   An abstract interpretation for SPMD divergence
//   on reducible control flow graphs.
//   Julian Rosemann, Simon Moll and Sebastian Hack
//   POPL '21
//
//
// -- Sync dependence --
// Sync dependence characterizes the control flow aspect of the
// propagation of branch divergence. For example,
//
//   %cond = icmp slt i32 %tid, 10
//   br i1 %cond, label %then, label %else
// then:
//   br label %merge
// else:
//   br label %merge
// merge:
//   %a = phi i32 [ 0, %then ], [ 1, %else ]
//
// Suppose %tid holds the thread ID. Although %a is not data dependent on %tid
// because %tid is not on its use-def chains, %a is sync dependent on %tid
// because the branch "br i1 %cond" depends on %tid and affects which value %a
// is assigned to.
//
//
// -- Reduction to SSA construction --
// There are two disjoint paths from A to X, if a certain variant of SSA
// construction places a phi node in X under the following set-up scheme.
//
// This variant of SSA construction ignores incoming undef values.
// That is, paths from the entry without a definition do not result in
// phi nodes.
//
//       entry
//      /      \
//     A        \
//    / \        Y
//   B   C      /
//    \ / \    /
//     D   E
//      \ /
//       F
//
// Assume that A contains a divergent branch. We are interested
// in the set of all blocks where each block is reachable from A
// via two disjoint paths. This would be the set {D, F} in this
// case.
// To generally reduce this query to SSA construction we introduce
// a virtual variable x and assign to x different values in each
// successor block of A.
//
//        entry
//       /      \
//      A        \
//     / \        Y
//  x = 0 x = 1  /
//     \ / \    /
//      D   E
//       \ /
//        F
//
// Our flavor of SSA construction for x will construct the following
//
//         entry
//        /      \
//       A        \
//      / \        Y
// x0 = 0  x1 = 1 /
//      \ / \    /
//  x2 = phi  E
//        \  /
//      x3 = phi
//
// The blocks D and F contain phi nodes and are thus each reachable
// by two disjoint paths from A.
//
// -- Remarks --
//   * In case of cycle exits we need to check for temporal divergence.
//     To this end, we check whether the definition of x differs between the
//     cycle exit and the cycle header (_after_ SSA construction).
//
//   * In the presence of irreducible control flow, the fixed point is
//     reached only after multiple iterations. This is because labels
//     reaching the header of a cycle must be repropagated through the
//     cycle. This is true even in a reducible cycle, since the labels
//     may have been produced by a nested irreducible cycle.
//
//   * Note that SyncDependenceAnalysis is not concerned with the points
//     of convergence in an irreducible cycle. Its only purpose is to
//     identify join blocks. The "diverged entry" criterion is
//     separately applied on join blocks to determine if an entire
//     irreducible cycle is assumed to be divergent.
//
//   * Relevant related work:
//     A simple algorithm for global data flow analysis problems.
//     Matthew S. Hecht and Jeffrey D. Ullman.
//     SIAM Journal on Computing, 4(4):519-532, December 1975.
//
template <typename ContextT> class GenericSyncDependenceAnalysis {
public:
  using BlockT = typename ContextT::BlockT;
  using DominatorTreeT = typename ContextT::DominatorTreeT;
  using FunctionT = typename ContextT::FunctionT;
  using ValueRefT = typename ContextT::ValueRefT;
  using InstructionT = typename ContextT::InstructionT;

  using CycleInfoT = GenericCycleInfo<ContextT>;
  using CycleT = typename CycleInfoT::CycleT;

  using ConstBlockSet = SmallPtrSet<const BlockT *, 4>;
  using ModifiedPO = ModifiedPostOrder<ContextT>;

  // * if BlockLabels[B] == C then C is the dominating definition at
  //   block B
  // * if BlockLabels[B] == nullptr then we haven't seen B yet
  // * if BlockLabels[B] == B then:
  //   - B is a join point of disjoint paths from X, or,
  //   - B is an immediate successor of X (initial value), or,
  //   - B is X
  using BlockLabelMap = DenseMap<const BlockT *, const BlockT *>;

  /// Information discovered by the sync dependence analysis for each
  /// divergent branch.
  struct DivergenceDescriptor {
    // Join points of diverged paths.
    ConstBlockSet JoinDivBlocks;
    // Divergent cycle exits
    ConstBlockSet CycleDivBlocks;
    // Labels assigned to blocks on diverged paths.
    BlockLabelMap BlockLabels;
  };

  using DivergencePropagatorT = DivergencePropagator<ContextT>;

  GenericSyncDependenceAnalysis(const ContextT &Context,
                                const DominatorTreeT &DT, const CycleInfoT &CI);

  /// \brief Computes divergent join points and cycle exits caused by branch
  /// divergence in \p Term.
  ///
  /// This returns a pair of sets:
  /// * The set of blocks which are reachable by disjoint paths from
  ///   \p Term.
  /// * The set also contains cycle exits if there are two disjoint paths:
  ///   one from \p Term to the cycle exit and another from \p Term to
  ///   the cycle header.
  const DivergenceDescriptor &getJoinBlocks(const BlockT *DivTermBlock);

private:
  static DivergenceDescriptor EmptyDivergenceDesc;

  ModifiedPO CyclePO;

  const DominatorTreeT &DT;
  const CycleInfoT &CI;

  // Per-branch cache of computed divergence descriptors.
  DenseMap<const BlockT *, std::unique_ptr<DivergenceDescriptor>>
      CachedControlDivDescs;
};

/// \brief Analysis that identifies uniform values in a data-parallel
/// execution.
///
/// This analysis propagates divergence in a data-parallel context
/// from sources of divergence to all users. It can be instantiated
/// for an IR that provides a suitable SSAContext.
template <typename ContextT> class GenericUniformityAnalysisImpl {
public:
  using BlockT = typename ContextT::BlockT;
  using FunctionT = typename ContextT::FunctionT;
  using ValueRefT = typename ContextT::ValueRefT;
  using ConstValueRefT = typename ContextT::ConstValueRefT;
  using UseT = typename ContextT::UseT;
  using InstructionT = typename ContextT::InstructionT;
  using DominatorTreeT = typename ContextT::DominatorTreeT;

  using CycleInfoT = GenericCycleInfo<ContextT>;
  using CycleT = typename CycleInfoT::CycleT;

  using SyncDependenceAnalysisT = GenericSyncDependenceAnalysis<ContextT>;
  using DivergenceDescriptorT =
      typename SyncDependenceAnalysisT::DivergenceDescriptor;
  using BlockLabelMapT = typename SyncDependenceAnalysisT::BlockLabelMap;

  GenericUniformityAnalysisImpl(const FunctionT &F, const DominatorTreeT &DT,
                                const CycleInfoT &CI,
                                const TargetTransformInfo *TTI)
      : Context(CI.getSSAContext()), F(F), CI(CI), TTI(TTI), DT(DT),
        SDA(Context, DT, CI) {}

  void initialize();

  const FunctionT &getFunction() const { return F; }

  /// \brief Mark \p Instr as a value that
  /// is always uniform.
  void addUniformOverride(const InstructionT &Instr);

  /// \brief Examine \p I for divergent outputs and add to the worklist.
  void markDivergent(const InstructionT &I);

  /// \brief Mark \p DivVal as a divergent value.
  /// \returns Whether the tracked divergence state of \p DivVal changed.
  bool markDivergent(ConstValueRefT DivVal);

  /// \brief Mark outputs of \p Instr as divergent.
  /// \returns Whether the tracked divergence state of any output has changed.
  bool markDefsDivergent(const InstructionT &Instr);

  /// \brief Propagate divergence to all instructions in the region.
  /// Divergence is seeded by calls to \p markDivergent.
  void compute();

  /// \brief Whether any value was marked or analyzed to be divergent.
  bool hasDivergence() const { return !DivergentValues.empty(); }

  /// \brief Whether \p Instr will always return a uniform value regardless
  /// of its operands.
  bool isAlwaysUniform(const InstructionT &Instr) const;

  bool hasDivergentDefs(const InstructionT &I) const;

  bool isDivergent(const InstructionT &I) const {
    if (I.isTerminator()) {
      return DivergentTermBlocks.contains(I.getParent());
    }
    return hasDivergentDefs(I);
  };

  /// \brief Whether \p V is divergent at its definition.
  bool isDivergent(ConstValueRefT V) const { return DivergentValues.count(V); }

  bool isDivergentUse(const UseT &U) const;

  bool hasDivergentTerminator(const BlockT &B) const {
    return DivergentTermBlocks.contains(&B);
  }

  void print(raw_ostream &out) const;

protected:
  /// \brief Value/block pair representing a single phi input.
  struct PhiInput {
    ConstValueRefT value;
    BlockT *predBlock;

    PhiInput(ConstValueRefT value, BlockT *predBlock)
        : value(value), predBlock(predBlock) {}
  };

  const ContextT &Context;
  const FunctionT &F;
  const CycleInfoT &CI;
  const TargetTransformInfo *TTI = nullptr;

  // Detected/marked divergent values.
  std::set<ConstValueRefT> DivergentValues;
  SmallPtrSet<const BlockT *, 32> DivergentTermBlocks;

  // Internal worklist for divergence propagation.
  std::vector<const InstructionT *> Worklist;

  /// \brief Mark \p Term as divergent and push all Instructions that become
  /// divergent as a result on the worklist.
  void analyzeControlDivergence(const InstructionT &Term);

private:
  const DominatorTreeT &DT;

  // Recognized cycles with divergent exits.
  SmallPtrSet<const CycleT *, 16> DivergentExitCycles;

  // Cycles assumed to be divergent.
  //
  // We don't use a set here because every insertion needs an explicit
  // traversal of all existing members.
  SmallVector<const CycleT *> AssumedDivergent;

  // The SDA links divergent branches to divergent control-flow joins.
  SyncDependenceAnalysisT SDA;

  // Set of known-uniform values.
  SmallPtrSet<const InstructionT *, 32> UniformOverrides;

  /// \brief Mark all nodes in \p JoinBlock as divergent and push them on
  /// the worklist.
  void taintAndPushAllDefs(const BlockT &JoinBlock);

  /// \brief Mark all phi nodes in \p JoinBlock as divergent and push them on
  /// the worklist.
  void taintAndPushPhiNodes(const BlockT &JoinBlock);

  /// \brief Identify all Instructions that become divergent because \p DivExit
  /// is a divergent cycle exit of \p DivCycle. Mark those instructions as
  /// divergent and push them on the worklist.
  void propagateCycleExitDivergence(const BlockT &DivExit,
                                    const CycleT &DivCycle);

  /// Mark as divergent all external uses of values defined in \p DefCycle.
  void analyzeCycleExitDivergence(const CycleT &DefCycle);

  /// \brief Mark as divergent all uses of \p I that are outside \p DefCycle.
  void propagateTemporalDivergence(const InstructionT &I,
                                   const CycleT &DefCycle);

  /// \brief Push all users of \p I (in the region) to the worklist.
  void pushUsers(const InstructionT &I);
  void pushUsers(ConstValueRefT V);

  bool usesValueFromCycle(const InstructionT &I, const CycleT &DefCycle) const;

  /// \brief Whether \p Def is divergent when read in \p ObservingBlock.
  bool isTemporalDivergent(const BlockT &ObservingBlock,
                           const InstructionT &Def) const;
};

template <typename ImplT>
void GenericUniformityAnalysisImplDeleter<ImplT>::operator()(ImplT *Impl) {
  delete Impl;
}

/// Compute divergence starting with a divergent branch.
template <typename ContextT> class DivergencePropagator {
public:
  using BlockT = typename ContextT::BlockT;
  using DominatorTreeT = typename ContextT::DominatorTreeT;
  using FunctionT = typename ContextT::FunctionT;
  using ValueRefT = typename ContextT::ValueRefT;

  using CycleInfoT = GenericCycleInfo<ContextT>;
  using CycleT = typename CycleInfoT::CycleT;

  using ModifiedPO = ModifiedPostOrder<ContextT>;
  using SyncDependenceAnalysisT = GenericSyncDependenceAnalysis<ContextT>;
  using DivergenceDescriptorT =
      typename SyncDependenceAnalysisT::DivergenceDescriptor;
  using BlockLabelMapT = typename SyncDependenceAnalysisT::BlockLabelMap;

  const ModifiedPO &CyclePOT;
  const DominatorTreeT &DT;
  const CycleInfoT &CI;
  const BlockT &DivTermBlock;
  const ContextT &Context;

  // Track blocks that receive a new label.
  // Every time we relabel a
  // cycle header, we make another pass over the modified post-order in
  // order to propagate the header label. The bit vector also allows
  // us to skip labels that have not changed.
  SparseBitVector<> FreshLabels;

  // divergent join and cycle exit descriptor.
  std::unique_ptr<DivergenceDescriptorT> DivDesc;
  BlockLabelMapT &BlockLabels;

  DivergencePropagator(const ModifiedPO &CyclePOT, const DominatorTreeT &DT,
                       const CycleInfoT &CI, const BlockT &DivTermBlock)
      : CyclePOT(CyclePOT), DT(DT), CI(CI), DivTermBlock(DivTermBlock),
        Context(CI.getSSAContext()), DivDesc(new DivergenceDescriptorT),
        BlockLabels(DivDesc->BlockLabels) {}

  /// Dump the current block labels in reverse modified post-order.
  void printDefs(raw_ostream &Out) {
    Out << "Propagator::BlockLabels {\n";
    for (int BlockIdx = (int)CyclePOT.size() - 1; BlockIdx >= 0; --BlockIdx) {
      const auto *Block = CyclePOT[BlockIdx];
      const auto *Label = BlockLabels[Block];
      Out << Context.print(Block) << "(" << BlockIdx << ") : ";
      if (!Label) {
        Out << "<null>\n";
      } else {
        Out << Context.print(Label) << "\n";
      }
    }
    Out << "}\n";
  }

  // Push a definition (\p PushedLabel) to \p SuccBlock and return whether this
  // causes a divergent join.
  bool computeJoin(const BlockT &SuccBlock, const BlockT &PushedLabel) {
    const auto *OldLabel = BlockLabels[&SuccBlock];

    LLVM_DEBUG(dbgs() << "labeling " << Context.print(&SuccBlock) << ":\n"
                      << "\tpushed label: " << Context.print(&PushedLabel)
                      << "\n"
                      << "\told label: " << Context.print(OldLabel) << "\n");

    // Early exit if there is no change in the label.
    if (OldLabel == &PushedLabel)
      return false;

    if (OldLabel != &SuccBlock) {
      auto SuccIdx = CyclePOT.getIndex(&SuccBlock);
      // Assigning a new label, mark this in FreshLabels.
      LLVM_DEBUG(dbgs() << "\tfresh label: " << SuccIdx << "\n");
      FreshLabels.set(SuccIdx);
    }

    // This is not a join if the succ was previously unlabeled.
    if (!OldLabel) {
      LLVM_DEBUG(dbgs() << "\tnew label: " << Context.print(&PushedLabel)
                        << "\n");
      BlockLabels[&SuccBlock] = &PushedLabel;
      return false;
    }

    // This is a new join. Label the join block as itself, and not as
    // the pushed label.
    LLVM_DEBUG(dbgs() << "\tnew label: " << Context.print(&SuccBlock) << "\n");
    BlockLabels[&SuccBlock] = &SuccBlock;

    return true;
  }

  // visiting a virtual cycle exit edge from the cycle header --> temporal
  // divergence on join
  bool visitCycleExitEdge(const BlockT &ExitBlock, const BlockT &Label) {
    if (!computeJoin(ExitBlock, Label))
      return false;

    // Identified a divergent cycle exit
    DivDesc->CycleDivBlocks.insert(&ExitBlock);
    LLVM_DEBUG(dbgs() << "\tDivergent cycle exit: " << Context.print(&ExitBlock)
                      << "\n");
    return true;
  }

  // process \p SuccBlock with reaching definition \p Label
  bool visitEdge(const BlockT &SuccBlock, const BlockT &Label) {
    if (!computeJoin(SuccBlock, Label))
      return false;

    // Divergent, disjoint paths join.
    DivDesc->JoinDivBlocks.insert(&SuccBlock);
    LLVM_DEBUG(dbgs() << "\tDivergent join: " << Context.print(&SuccBlock)
                      << "\n");
    return true;
  }

  /// Propagate labels in reverse modified post-order until a fixed point
  /// is reached, collecting divergent joins and divergent cycle exits.
  std::unique_ptr<DivergenceDescriptorT> computeJoinPoints() {
    assert(DivDesc);

    LLVM_DEBUG(dbgs() << "SDA:computeJoinPoints: "
                      << Context.print(&DivTermBlock) << "\n");

    // Early stopping criterion
    int FloorIdx = CyclePOT.size() - 1;
    const BlockT *FloorLabel = nullptr;
    int DivTermIdx = CyclePOT.getIndex(&DivTermBlock);

    // Bootstrap with branch targets
    auto const *DivTermCycle = CI.getCycle(&DivTermBlock);
    for (const auto *SuccBlock : successors(&DivTermBlock)) {
      if (DivTermCycle && !DivTermCycle->contains(SuccBlock)) {
        // If DivTerm exits the cycle immediately, computeJoin() might
        // not reach SuccBlock with a different label. We need to
        // check for this exit now.
        DivDesc->CycleDivBlocks.insert(SuccBlock);
        LLVM_DEBUG(dbgs() << "\tImmediate divergent cycle exit: "
                          << Context.print(SuccBlock) << "\n");
      }
      auto SuccIdx = CyclePOT.getIndex(SuccBlock);
      visitEdge(*SuccBlock, *SuccBlock);
      FloorIdx = std::min<int>(FloorIdx, SuccIdx);
    }

    while (true) {
      auto BlockIdx = FreshLabels.find_last();
      if (BlockIdx == -1 || BlockIdx < FloorIdx)
        break;

      LLVM_DEBUG(dbgs() << "Current labels:\n"; printDefs(dbgs()));

      FreshLabels.reset(BlockIdx);
      if (BlockIdx == DivTermIdx) {
        LLVM_DEBUG(dbgs() << "Skipping DivTermBlock\n");
        continue;
      }

      const auto *Block = CyclePOT[BlockIdx];
      LLVM_DEBUG(dbgs() << "visiting " << Context.print(Block) << " at index "
                        << BlockIdx << "\n");

      const auto *Label = BlockLabels[Block];
      assert(Label);

      bool CausedJoin = false;
      int LoweredFloorIdx = FloorIdx;

      // If the current block is the header of a reducible cycle that
      // contains the divergent branch, then the label should be
      // propagated to the cycle exits. Such a header is the "last
      // possible join" of any disjoint paths within this cycle. This
      // prevents detection of spurious joins at the entries of any
      // irreducible child cycles.
      //
      // This conclusion about the header is true for any choice of DFS:
      //
      //   If some DFS has a reducible cycle C with header H, then for
      //   any other DFS, H is the header of a cycle C' that is a
      //   superset of C. For a divergent branch inside the subgraph
      //   C, any join node inside C is either H, or some node
      //   encountered without passing through H.
      //
      auto getReducibleParent = [&](const BlockT *Block) -> const CycleT * {
        if (!CyclePOT.isReducibleCycleHeader(Block))
          return nullptr;
        const auto *BlockCycle = CI.getCycle(Block);
        if (BlockCycle->contains(&DivTermBlock))
          return BlockCycle;
        return nullptr;
      };

      if (const auto *BlockCycle = getReducibleParent(Block)) {
        SmallVector<BlockT *, 4> BlockCycleExits;
        BlockCycle->getExitBlocks(BlockCycleExits);
        for (auto *BlockCycleExit : BlockCycleExits) {
          CausedJoin |= visitCycleExitEdge(*BlockCycleExit, *Label);
          LoweredFloorIdx =
              std::min<int>(LoweredFloorIdx, CyclePOT.getIndex(BlockCycleExit));
        }
      } else {
        for (const auto *SuccBlock : successors(Block)) {
          CausedJoin |= visitEdge(*SuccBlock, *Label);
          LoweredFloorIdx =
              std::min<int>(LoweredFloorIdx, CyclePOT.getIndex(SuccBlock));
        }
      }

      // Floor update
      if (CausedJoin) {
        // 1. Different labels pushed to successors
        FloorIdx = LoweredFloorIdx;
      } else if (FloorLabel != Label) {
        // 2. No join caused BUT we pushed a label that is different than the
        //    last pushed label
        FloorIdx = LoweredFloorIdx;
        FloorLabel = Label;
      }
    }

    LLVM_DEBUG(dbgs() << "Final labeling:\n"; printDefs(dbgs()));

    // Check every cycle containing DivTermBlock for exit divergence.
    // A cycle has exit divergence if the label of an exit block does
    // not match the label of its header.
    for (const auto *Cycle = CI.getCycle(&DivTermBlock); Cycle;
         Cycle = Cycle->getParentCycle()) {
      if (Cycle->isReducible()) {
        // The exit divergence of a reducible cycle is recorded while
        // propagating labels.
        continue;
      }
      SmallVector<BlockT *> Exits;
      Cycle->getExitBlocks(Exits);
      auto *Header = Cycle->getHeader();
      auto *HeaderLabel = BlockLabels[Header];
      for (const auto *Exit : Exits) {
        if (BlockLabels[Exit] != HeaderLabel) {
          // Identified a divergent cycle exit
          DivDesc->CycleDivBlocks.insert(Exit);
          LLVM_DEBUG(dbgs() << "\tDivergent cycle exit: " << Context.print(Exit)
                            << "\n");
        }
      }
    }

    return std::move(DivDesc);
  }
};

template <typename ContextT>
typename llvm::GenericSyncDependenceAnalysis<ContextT>::DivergenceDescriptor
    llvm::GenericSyncDependenceAnalysis<ContextT>::EmptyDivergenceDesc;

template <typename ContextT>
llvm::GenericSyncDependenceAnalysis<ContextT>::GenericSyncDependenceAnalysis(
    const ContextT &Context, const DominatorTreeT &DT, const CycleInfoT &CI)
    : CyclePO(Context), DT(DT), CI(CI) {
  CyclePO.compute(CI);
}

template <typename ContextT>
auto llvm::GenericSyncDependenceAnalysis<ContextT>::getJoinBlocks(
    const BlockT *DivTermBlock) -> const DivergenceDescriptor & {
  // trivial case
  if (succ_size(DivTermBlock) <= 1) {
    return EmptyDivergenceDesc;
  }

  // already available in cache?
  auto ItCached = CachedControlDivDescs.find(DivTermBlock);
  if (ItCached != CachedControlDivDescs.end())
    return *ItCached->second;

  // compute all join points
  DivergencePropagatorT Propagator(CyclePO, DT, CI, *DivTermBlock);
  auto DivDesc = Propagator.computeJoinPoints();

  auto printBlockSet = [&](ConstBlockSet &Blocks) {
    return Printable([&](raw_ostream &Out) {
      Out << "[";
      ListSeparator LS;
      for (const auto *BB : Blocks) {
        Out << LS << CI.getSSAContext().print(BB);
      }
      Out << "]\n";
    });
  };

  LLVM_DEBUG(
      dbgs() << "\nResult (" << CI.getSSAContext().print(DivTermBlock)
             << "):\n JoinDivBlocks: " << printBlockSet(DivDesc->JoinDivBlocks)
             << " CycleDivBlocks: " << printBlockSet(DivDesc->CycleDivBlocks)
             << "\n");
  (void)printBlockSet;

  // Cache the result; the entry must not already exist.
  auto ItInserted =
      CachedControlDivDescs.try_emplace(DivTermBlock, std::move(DivDesc));
  assert(ItInserted.second);
  return *ItInserted.first->second;
}

template <typename ContextT>
void GenericUniformityAnalysisImpl<ContextT>::markDivergent(
    const InstructionT &I) {
  if (isAlwaysUniform(I))
    return;
  bool Marked = false;
  if (I.isTerminator()) {
    Marked = DivergentTermBlocks.insert(I.getParent()).second;
    if (Marked) {
      LLVM_DEBUG(dbgs() << "marked divergent term block: "
                        << Context.print(I.getParent()) << "\n");
    }
  } else {
    Marked = markDefsDivergent(I);
  }

  if (Marked)
    Worklist.push_back(&I);
}

template <typename ContextT>
bool GenericUniformityAnalysisImpl<ContextT>::markDivergent(
    ConstValueRefT Val) {
  if (DivergentValues.insert(Val).second) {
    LLVM_DEBUG(dbgs() << "marked divergent: " << Context.print(Val) << "\n");
    return true;
  }
  return false;
}

template <typename ContextT>
void GenericUniformityAnalysisImpl<ContextT>::addUniformOverride(
    const InstructionT &Instr) {
  UniformOverrides.insert(&Instr);
}

// Mark as divergent all external uses of values defined in \p DefCycle.
//
// A value V defined by a block B inside \p DefCycle may be used outside the
// cycle only if the use is a PHI in some exit block, or B dominates some exit
// block. Thus, we check uses as follows:
//
// - Check all PHIs in all exit blocks for inputs defined inside \p DefCycle.
// - For every block B inside \p DefCycle that dominates at least one exit
//   block, check all uses outside \p DefCycle.
//
// FIXME: This function does not distinguish between divergent and uniform
// exits. For each divergent exit, only the values that are live at that exit
// need to be propagated as divergent at their use outside the cycle.
template <typename ContextT>
void GenericUniformityAnalysisImpl<ContextT>::analyzeCycleExitDivergence(
    const CycleT &DefCycle) {
  SmallVector<BlockT *> Exits;
  DefCycle.getExitBlocks(Exits);
  for (auto *Exit : Exits) {
    for (auto &Phi : Exit->phis()) {
      if (usesValueFromCycle(Phi, DefCycle)) {
        markDivergent(Phi);
      }
    }
  }

  for (auto *BB : DefCycle.blocks()) {
    // Only blocks that dominate an exit can have uses that are live
    // outside the cycle (other than via the exit PHIs handled above).
    if (!llvm::any_of(Exits,
                      [&](BlockT *Exit) { return DT.dominates(BB, Exit); }))
      continue;
    for (auto &II : *BB) {
      propagateTemporalDivergence(II, DefCycle);
    }
  }
}

template <typename ContextT>
void GenericUniformityAnalysisImpl<ContextT>::propagateCycleExitDivergence(
    const BlockT &DivExit, const CycleT &InnerDivCycle) {
  LLVM_DEBUG(dbgs() << "\tpropCycleExitDiv " << Context.print(&DivExit)
                    << "\n");
  auto *DivCycle = &InnerDivCycle;
  auto *OuterDivCycle = DivCycle;
  auto *ExitLevelCycle = CI.getCycle(&DivExit);
  const unsigned CycleExitDepth =
      ExitLevelCycle ? ExitLevelCycle->getDepth() : 0;

  // Find outer-most cycle that does not contain \p DivExit
  while (DivCycle && DivCycle->getDepth() > CycleExitDepth) {
    LLVM_DEBUG(dbgs() << "  Found exiting cycle: "
                      << Context.print(DivCycle->getHeader()) << "\n");
    OuterDivCycle = DivCycle;
    DivCycle = DivCycle->getParentCycle();
  }
  LLVM_DEBUG(dbgs() << "\tOuter-most exiting cycle: "
                    << Context.print(OuterDivCycle->getHeader()) << "\n");

  if (!DivergentExitCycles.insert(OuterDivCycle).second)
    return;

  // Exit divergence does not matter if the cycle itself is assumed to
  // be divergent.
  for (const auto *C : AssumedDivergent) {
    if (C->contains(OuterDivCycle))
      return;
  }

  analyzeCycleExitDivergence(*OuterDivCycle);
}

template <typename ContextT>
void GenericUniformityAnalysisImpl<ContextT>::taintAndPushAllDefs(
    const BlockT &BB) {
  LLVM_DEBUG(dbgs() << "taintAndPushAllDefs " << Context.print(&BB) << "\n");
  for (const auto &I : instrs(BB)) {
    // Terminators do not produce values; they are divergent only if
    // the condition is divergent. That is handled when the divergent
    // condition is placed in the worklist.
    if (I.isTerminator())
      break;

    markDivergent(I);
  }
}

/// Mark divergent phi nodes in a join block
template <typename ContextT>
void GenericUniformityAnalysisImpl<ContextT>::taintAndPushPhiNodes(
    const BlockT &JoinBlock) {
  LLVM_DEBUG(dbgs() << "taintAndPushPhiNodes in " << Context.print(&JoinBlock)
                    << "\n");
  for (const auto &Phi : JoinBlock.phis()) {
    // FIXME: The non-undef value is not constant per se; it just happens to be
    // uniform and may not dominate this PHI. So assuming that the same value
    // reaches along all incoming edges may itself be undefined behaviour.
This 912 // particular interpretation of the undef value was added to 913 // DivergenceAnalysis in the following review: 914 // 915 // https://reviews.llvm.org/D19013 916 if (ContextT::isConstantOrUndefValuePhi(Phi)) 917 continue; 918 markDivergent(Phi); 919 } 920 } 921 922 /// Add \p Candidate to \p Cycles if it is not already contained in \p Cycles. 923 /// 924 /// \return true iff \p Candidate was added to \p Cycles. 925 template <typename CycleT> 926 static bool insertIfNotContained(SmallVector<CycleT *> &Cycles, 927 CycleT *Candidate) { 928 if (llvm::any_of(Cycles, 929 [Candidate](CycleT *C) { return C->contains(Candidate); })) 930 return false; 931 Cycles.push_back(Candidate); 932 return true; 933 } 934 935 /// Return the outermost cycle made divergent by branch outside it. 936 /// 937 /// If two paths that diverged outside an irreducible cycle join 938 /// inside that cycle, then that whole cycle is assumed to be 939 /// divergent. This does not apply if the cycle is reducible. 940 template <typename CycleT, typename BlockT> 941 static const CycleT *getExtDivCycle(const CycleT *Cycle, 942 const BlockT *DivTermBlock, 943 const BlockT *JoinBlock) { 944 assert(Cycle); 945 assert(Cycle->contains(JoinBlock)); 946 947 if (Cycle->contains(DivTermBlock)) 948 return nullptr; 949 950 if (Cycle->isReducible()) { 951 assert(Cycle->getHeader() == JoinBlock); 952 return nullptr; 953 } 954 955 const auto *Parent = Cycle->getParentCycle(); 956 while (Parent && !Parent->contains(DivTermBlock)) { 957 // If the join is inside a child, then the parent must be 958 // irreducible. The only join in a reducible cyle is its own 959 // header. 960 assert(!Parent->isReducible()); 961 Cycle = Parent; 962 Parent = Cycle->getParentCycle(); 963 } 964 965 LLVM_DEBUG(dbgs() << "cycle made divergent by external branch\n"); 966 return Cycle; 967 } 968 969 /// Return the outermost cycle made divergent by branch inside it. 
970 /// 971 /// This checks the "diverged entry" criterion defined in the 972 /// docs/ConvergenceAnalysis.html. 973 template <typename ContextT, typename CycleT, typename BlockT, 974 typename DominatorTreeT> 975 static const CycleT * 976 getIntDivCycle(const CycleT *Cycle, const BlockT *DivTermBlock, 977 const BlockT *JoinBlock, const DominatorTreeT &DT, 978 ContextT &Context) { 979 LLVM_DEBUG(dbgs() << "examine join " << Context.print(JoinBlock) 980 << "for internal branch " << Context.print(DivTermBlock) 981 << "\n"); 982 if (DT.properlyDominates(DivTermBlock, JoinBlock)) 983 return nullptr; 984 985 // Find the smallest common cycle, if one exists. 986 assert(Cycle && Cycle->contains(JoinBlock)); 987 while (Cycle && !Cycle->contains(DivTermBlock)) { 988 Cycle = Cycle->getParentCycle(); 989 } 990 if (!Cycle || Cycle->isReducible()) 991 return nullptr; 992 993 if (DT.properlyDominates(Cycle->getHeader(), JoinBlock)) 994 return nullptr; 995 996 LLVM_DEBUG(dbgs() << " header " << Context.print(Cycle->getHeader()) 997 << " does not dominate join\n"); 998 999 const auto *Parent = Cycle->getParentCycle(); 1000 while (Parent && !DT.properlyDominates(Parent->getHeader(), JoinBlock)) { 1001 LLVM_DEBUG(dbgs() << " header " << Context.print(Parent->getHeader()) 1002 << " does not dominate join\n"); 1003 Cycle = Parent; 1004 Parent = Parent->getParentCycle(); 1005 } 1006 1007 LLVM_DEBUG(dbgs() << " cycle made divergent by internal branch\n"); 1008 return Cycle; 1009 } 1010 1011 template <typename ContextT, typename CycleT, typename BlockT, 1012 typename DominatorTreeT> 1013 static const CycleT * 1014 getOutermostDivergentCycle(const CycleT *Cycle, const BlockT *DivTermBlock, 1015 const BlockT *JoinBlock, const DominatorTreeT &DT, 1016 ContextT &Context) { 1017 if (!Cycle) 1018 return nullptr; 1019 1020 // First try to expand Cycle to the largest that contains JoinBlock 1021 // but not DivTermBlock. 
1022 const auto *Ext = getExtDivCycle(Cycle, DivTermBlock, JoinBlock); 1023 1024 // Continue expanding to the largest cycle that contains both. 1025 const auto *Int = getIntDivCycle(Cycle, DivTermBlock, JoinBlock, DT, Context); 1026 1027 if (Int) 1028 return Int; 1029 return Ext; 1030 } 1031 1032 template <typename ContextT> 1033 bool GenericUniformityAnalysisImpl<ContextT>::isTemporalDivergent( 1034 const BlockT &ObservingBlock, const InstructionT &Def) const { 1035 const BlockT *DefBlock = Def.getParent(); 1036 for (const CycleT *Cycle = CI.getCycle(DefBlock); 1037 Cycle && !Cycle->contains(&ObservingBlock); 1038 Cycle = Cycle->getParentCycle()) { 1039 if (DivergentExitCycles.contains(Cycle)) { 1040 return true; 1041 } 1042 } 1043 return false; 1044 } 1045 1046 template <typename ContextT> 1047 void GenericUniformityAnalysisImpl<ContextT>::analyzeControlDivergence( 1048 const InstructionT &Term) { 1049 const auto *DivTermBlock = Term.getParent(); 1050 DivergentTermBlocks.insert(DivTermBlock); 1051 LLVM_DEBUG(dbgs() << "analyzeControlDiv " << Context.print(DivTermBlock) 1052 << "\n"); 1053 1054 // Don't propagate divergence from unreachable blocks. 1055 if (!DT.isReachableFromEntry(DivTermBlock)) 1056 return; 1057 1058 const auto &DivDesc = SDA.getJoinBlocks(DivTermBlock); 1059 SmallVector<const CycleT *> DivCycles; 1060 1061 // Iterate over all blocks now reachable by a disjoint path join 1062 for (const auto *JoinBlock : DivDesc.JoinDivBlocks) { 1063 const auto *Cycle = CI.getCycle(JoinBlock); 1064 LLVM_DEBUG(dbgs() << "visiting join block " << Context.print(JoinBlock) 1065 << "\n"); 1066 if (const auto *Outermost = getOutermostDivergentCycle( 1067 Cycle, DivTermBlock, JoinBlock, DT, Context)) { 1068 LLVM_DEBUG(dbgs() << "found divergent cycle\n"); 1069 DivCycles.push_back(Outermost); 1070 continue; 1071 } 1072 taintAndPushPhiNodes(*JoinBlock); 1073 } 1074 1075 // Sort by order of decreasing depth. 
// This allows later cycles to be skipped
  // because they are already contained in earlier ones.
  llvm::sort(DivCycles, [](const CycleT *A, const CycleT *B) {
    return A->getDepth() > B->getDepth();
  });

  // Cycles that are assumed divergent due to the diverged entry
  // criterion potentially contain temporal divergence depending on
  // the DFS chosen. Conservatively, all values produced in such a
  // cycle are assumed divergent. "Cycle invariant" values may be
  // assumed uniform, but that requires further analysis.
  for (auto *C : DivCycles) {
    if (!insertIfNotContained(AssumedDivergent, C))
      continue;
    LLVM_DEBUG(dbgs() << "process divergent cycle\n");
    for (const BlockT *BB : C->blocks()) {
      taintAndPushAllDefs(*BB);
    }
  }

  // If the divergent terminator sits inside a cycle, also handle the
  // cycle-divergent exit blocks reported by the SDA.
  const auto *BranchCycle = CI.getCycle(DivTermBlock);
  assert(DivDesc.CycleDivBlocks.empty() || BranchCycle);
  for (const auto *DivExitBlock : DivDesc.CycleDivBlocks) {
    propagateCycleExitDivergence(*DivExitBlock, *BranchCycle);
  }
}

/// Run the fixed-point propagation: seed the worklist from the known
/// divergent values, then drain it, analyzing control divergence at
/// terminators and pushing users of divergent values.
template <typename ContextT>
void GenericUniformityAnalysisImpl<ContextT>::compute() {
  // Initialize worklist.
  // Copy first: pushUsers may grow DivergentValues while we iterate.
  auto DivValuesCopy = DivergentValues;
  for (const auto DivVal : DivValuesCopy) {
    assert(isDivergent(DivVal) && "Worklist invariant violated!");
    pushUsers(DivVal);
  }

  // All values on the Worklist are divergent.
  // Their users may not have been updated yet.
  while (!Worklist.empty()) {
    const InstructionT *I = Worklist.back();
    Worklist.pop_back();

    LLVM_DEBUG(dbgs() << "worklist pop: " << Context.print(I) << "\n");

    if (I->isTerminator()) {
      analyzeControlDivergence(*I);
      continue;
    }

    // propagate value divergence to users
    assert(isDivergent(*I) && "Worklist invariant violated!");
    pushUsers(*I);
  }
}

/// Whether \p Instr was registered as always-uniform in UniformOverrides.
template <typename ContextT>
bool GenericUniformityAnalysisImpl<ContextT>::isAlwaysUniform(
    const InstructionT &Instr) const {
  return UniformOverrides.contains(&Instr);
}

/// Construct the analysis wrapper and immediately run the underlying
/// implementation over \p Func.
template <typename ContextT>
GenericUniformityInfo<ContextT>::GenericUniformityInfo(
    FunctionT &Func, const DominatorTreeT &DT, const CycleInfoT &CI,
    const TargetTransformInfo *TTI)
    : F(&Func) {
  DA.reset(new ImplT{Func, DT, CI, TTI});
}

/// Print divergent arguments, cycles, and a per-block report of divergent
/// definitions and terminators to \p OS.
template <typename ContextT>
void GenericUniformityAnalysisImpl<ContextT>::print(raw_ostream &OS) const {
  bool haveDivergentArgs = false;

  // Control flow instructions may be divergent even if their inputs are
  // uniform. Thus, although exceedingly rare, it is possible to have a program
  // with no divergent values but with divergent control structures.
1151 if (DivergentValues.empty() && DivergentTermBlocks.empty() && 1152 DivergentExitCycles.empty()) { 1153 OS << "ALL VALUES UNIFORM\n"; 1154 return; 1155 } 1156 1157 for (const auto &entry : DivergentValues) { 1158 const BlockT *parent = Context.getDefBlock(entry); 1159 if (!parent) { 1160 if (!haveDivergentArgs) { 1161 OS << "DIVERGENT ARGUMENTS:\n"; 1162 haveDivergentArgs = true; 1163 } 1164 OS << " DIVERGENT: " << Context.print(entry) << '\n'; 1165 } 1166 } 1167 1168 if (!AssumedDivergent.empty()) { 1169 OS << "CYCLES ASSSUMED DIVERGENT:\n"; 1170 for (const CycleT *cycle : AssumedDivergent) { 1171 OS << " " << cycle->print(Context) << '\n'; 1172 } 1173 } 1174 1175 if (!DivergentExitCycles.empty()) { 1176 OS << "CYCLES WITH DIVERGENT EXIT:\n"; 1177 for (const CycleT *cycle : DivergentExitCycles) { 1178 OS << " " << cycle->print(Context) << '\n'; 1179 } 1180 } 1181 1182 for (auto &block : F) { 1183 OS << "\nBLOCK " << Context.print(&block) << '\n'; 1184 1185 OS << "DEFINITIONS\n"; 1186 SmallVector<ConstValueRefT, 16> defs; 1187 Context.appendBlockDefs(defs, block); 1188 for (auto value : defs) { 1189 if (isDivergent(value)) 1190 OS << " DIVERGENT: "; 1191 else 1192 OS << " "; 1193 OS << Context.print(value) << '\n'; 1194 } 1195 1196 OS << "TERMINATORS\n"; 1197 SmallVector<const InstructionT *, 8> terms; 1198 Context.appendBlockTerms(terms, block); 1199 bool divergentTerminators = hasDivergentTerminator(block); 1200 for (auto *T : terms) { 1201 if (divergentTerminators) 1202 OS << " DIVERGENT: "; 1203 else 1204 OS << " "; 1205 OS << Context.print(T) << '\n'; 1206 } 1207 1208 OS << "END BLOCK\n"; 1209 } 1210 } 1211 1212 template <typename ContextT> 1213 bool GenericUniformityInfo<ContextT>::hasDivergence() const { 1214 return DA->hasDivergence(); 1215 } 1216 1217 /// Whether \p V is divergent at its definition. 
1218 template <typename ContextT> 1219 bool GenericUniformityInfo<ContextT>::isDivergent(ConstValueRefT V) const { 1220 return DA->isDivergent(V); 1221 } 1222 1223 template <typename ContextT> 1224 bool GenericUniformityInfo<ContextT>::isDivergent(const InstructionT *I) const { 1225 return DA->isDivergent(*I); 1226 } 1227 1228 template <typename ContextT> 1229 bool GenericUniformityInfo<ContextT>::isDivergentUse(const UseT &U) const { 1230 return DA->isDivergentUse(U); 1231 } 1232 1233 template <typename ContextT> 1234 bool GenericUniformityInfo<ContextT>::hasDivergentTerminator(const BlockT &B) { 1235 return DA->hasDivergentTerminator(B); 1236 } 1237 1238 /// \brief T helper function for printing. 1239 template <typename ContextT> 1240 void GenericUniformityInfo<ContextT>::print(raw_ostream &out) const { 1241 DA->print(out); 1242 } 1243 1244 template <typename ContextT> 1245 void llvm::ModifiedPostOrder<ContextT>::computeStackPO( 1246 SmallVectorImpl<BlockT *> &Stack, const CycleInfoT &CI, const CycleT *Cycle, 1247 SmallPtrSetImpl<BlockT *> &Finalized) { 1248 LLVM_DEBUG(dbgs() << "inside computeStackPO\n"); 1249 while (!Stack.empty()) { 1250 auto *NextBB = Stack.back(); 1251 if (Finalized.count(NextBB)) { 1252 Stack.pop_back(); 1253 continue; 1254 } 1255 LLVM_DEBUG(dbgs() << " visiting " << CI.getSSAContext().print(NextBB) 1256 << "\n"); 1257 auto *NestedCycle = CI.getCycle(NextBB); 1258 if (Cycle != NestedCycle && (!Cycle || Cycle->contains(NestedCycle))) { 1259 LLVM_DEBUG(dbgs() << " found a cycle\n"); 1260 while (NestedCycle->getParentCycle() != Cycle) 1261 NestedCycle = NestedCycle->getParentCycle(); 1262 1263 SmallVector<BlockT *, 3> NestedExits; 1264 NestedCycle->getExitBlocks(NestedExits); 1265 bool PushedNodes = false; 1266 for (auto *NestedExitBB : NestedExits) { 1267 LLVM_DEBUG(dbgs() << " examine exit: " 1268 << CI.getSSAContext().print(NestedExitBB) << "\n"); 1269 if (Cycle && !Cycle->contains(NestedExitBB)) 1270 continue; 1271 if 
 (Finalized.count(NestedExitBB))
          continue;
        PushedNodes = true;
        Stack.push_back(NestedExitBB);
        LLVM_DEBUG(dbgs() << " pushed exit: "
                          << CI.getSSAContext().print(NestedExitBB) << "\n");
      }
      if (!PushedNodes) {
        // All loop exits finalized -> finish this node
        Stack.pop_back();
        computeCyclePO(CI, NestedCycle, Finalized);
      }
      continue;
    }

    LLVM_DEBUG(dbgs() << " no nested cycle, going into DAG\n");
    // DAG-style
    bool PushedNodes = false;
    for (auto *SuccBB : successors(NextBB)) {
      LLVM_DEBUG(dbgs() << " examine succ: "
                        << CI.getSSAContext().print(SuccBB) << "\n");
      // Stay within the current cycle (if any) and skip finished blocks.
      if (Cycle && !Cycle->contains(SuccBB))
        continue;
      if (Finalized.count(SuccBB))
        continue;
      PushedNodes = true;
      Stack.push_back(SuccBB);
      LLVM_DEBUG(dbgs() << " pushed succ: " << CI.getSSAContext().print(SuccBB)
                        << "\n");
    }
    if (!PushedNodes) {
      // Never push nodes twice
      LLVM_DEBUG(dbgs() << " finishing node: "
                        << CI.getSSAContext().print(NextBB) << "\n");
      Stack.pop_back();
      Finalized.insert(NextBB);
      appendBlock(*NextBB);
    }
  }
  LLVM_DEBUG(dbgs() << "exited computeStackPO\n");
}

/// Number the blocks of \p Cycle: the header is appended (and finalized)
/// first, then the remaining blocks are numbered by computeStackPO starting
/// from the header's in-cycle successors. Combined with a reverse traversal,
/// this realizes the "header visited last" invariant from the file header.
template <typename ContextT>
void ModifiedPostOrder<ContextT>::computeCyclePO(
    const CycleInfoT &CI, const CycleT *Cycle,
    SmallPtrSetImpl<BlockT *> &Finalized) {
  LLVM_DEBUG(dbgs() << "inside computeCyclePO\n");
  SmallVector<BlockT *> Stack;
  auto *CycleHeader = Cycle->getHeader();

  LLVM_DEBUG(dbgs() << " noted header: "
                    << CI.getSSAContext().print(CycleHeader) << "\n");
  assert(!Finalized.count(CycleHeader));
  Finalized.insert(CycleHeader);

  // Visit the header last
  LLVM_DEBUG(dbgs() << " finishing header: "
                    << CI.getSSAContext().print(CycleHeader) << "\n");
  appendBlock(*CycleHeader, Cycle->isReducible());

  // Initialize with immediate successors
  for (auto *BB : successors(CycleHeader)) {
    LLVM_DEBUG(dbgs() << " examine succ: " << CI.getSSAContext().print(BB)
                      << "\n");
    if (!Cycle->contains(BB))
      continue;
    if (BB == CycleHeader)
      continue;
    if (!Finalized.count(BB)) {
      LLVM_DEBUG(dbgs() << " pushed succ: " << CI.getSSAContext().print(BB)
                        << "\n");
      Stack.push_back(BB);
    }
  }

  // Compute PO inside region
  computeStackPO(Stack, CI, Cycle, Finalized);

  LLVM_DEBUG(dbgs() << "exited computeCyclePO\n");
}

/// \brief Generically compute the modified post order.
template <typename ContextT>
void llvm::ModifiedPostOrder<ContextT>::compute(const CycleInfoT &CI) {
  SmallPtrSet<BlockT *, 32> Finalized;
  SmallVector<BlockT *> Stack;
  auto *F = CI.getFunction();
  Stack.reserve(24); // FIXME made-up number
  Stack.push_back(GraphTraits<FunctionT *>::getEntryNode(F));
  computeStackPO(Stack, CI, nullptr, Finalized);
}

} // namespace llvm

#undef DEBUG_TYPE

#endif // LLVM_ADT_GENERICUNIFORMITYIMPL_H