1 //===- llvm/Analysis/DivergenceAnalysis.h - Divergence Analysis -*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // \file 10 // The divergence analysis determines which instructions and branches are 11 // divergent given a set of divergent source instructions. 12 // 13 //===----------------------------------------------------------------------===// 14 15 #ifndef LLVM_ANALYSIS_DIVERGENCEANALYSIS_H 16 #define LLVM_ANALYSIS_DIVERGENCEANALYSIS_H 17 18 #include "llvm/ADT/DenseSet.h" 19 #include "llvm/Analysis/SyncDependenceAnalysis.h" 20 #include "llvm/IR/Function.h" 21 #include "llvm/Pass.h" 22 #include <vector> 23 24 namespace llvm { 25 class Module; 26 class Value; 27 class Instruction; 28 class Loop; 29 class raw_ostream; 30 class TargetTransformInfo; 31 32 /// \brief Generic divergence analysis for reducible CFGs. 33 /// 34 /// This analysis propagates divergence in a data-parallel context from sources 35 /// of divergence to all users. It requires reducible CFGs. All assignments 36 /// should be in SSA form. 37 class DivergenceAnalysisImpl { 38 public: 39 /// \brief This instance will analyze the whole function \p F or the loop \p 40 /// RegionLoop. 41 /// 42 /// \param RegionLoop if non-null the analysis is restricted to \p RegionLoop. 43 /// Otherwise the whole function is analyzed. 44 /// \param IsLCSSAForm whether the analysis may assume that the IR in the 45 /// region in in LCSSA form. 46 DivergenceAnalysisImpl(const Function &F, const Loop *RegionLoop, 47 const DominatorTree &DT, const LoopInfo &LI, 48 SyncDependenceAnalysis &SDA, bool IsLCSSAForm); 49 50 /// \brief The loop that defines the analyzed region (if any). getRegionLoop()51 const Loop *getRegionLoop() const { return RegionLoop; } getFunction()52 const Function &getFunction() const { return F; } 53 54 /// \brief Whether \p BB is part of the region. 55 bool inRegion(const BasicBlock &BB) const; 56 /// \brief Whether \p I is part of the region. 57 bool inRegion(const Instruction &I) const; 58 59 /// \brief Mark \p UniVal as a value that is always uniform. 60 void addUniformOverride(const Value &UniVal); 61 62 /// \brief Mark \p DivVal as a value that is always divergent. Will not do so 63 /// if `isAlwaysUniform(DivVal)`. 64 /// \returns Whether the tracked divergence state of \p DivVal changed. 65 bool markDivergent(const Value &DivVal); 66 67 /// \brief Propagate divergence to all instructions in the region. 68 /// Divergence is seeded by calls to \p markDivergent. 69 void compute(); 70 71 /// \brief Whether any value was marked or analyzed to be divergent. hasDetectedDivergence()72 bool hasDetectedDivergence() const { return !DivergentValues.empty(); } 73 74 /// \brief Whether \p Val will always return a uniform value regardless of its 75 /// operands 76 bool isAlwaysUniform(const Value &Val) const; 77 78 /// \brief Whether \p Val is divergent at its definition. 79 bool isDivergent(const Value &Val) const; 80 81 /// \brief Whether \p U is divergent. Uses of a uniform value can be 82 /// divergent. 83 bool isDivergentUse(const Use &U) const; 84 85 private: 86 /// \brief Mark \p Term as divergent and push all Instructions that become 87 /// divergent as a result on the worklist. 88 void analyzeControlDivergence(const Instruction &Term); 89 /// \brief Mark all phi nodes in \p JoinBlock as divergent and push them on 90 /// the worklist. 91 void taintAndPushPhiNodes(const BasicBlock &JoinBlock); 92 93 /// \brief Identify all Instructions that become divergent because \p DivExit 94 /// is a divergent loop exit of \p DivLoop. Mark those instructions as 95 /// divergent and push them on the worklist. 96 void propagateLoopExitDivergence(const BasicBlock &DivExit, 97 const Loop &DivLoop); 98 99 /// \brief Internal implementation function for propagateLoopExitDivergence. 100 void analyzeLoopExitDivergence(const BasicBlock &DivExit, 101 const Loop &OuterDivLoop); 102 103 /// \brief Mark all instruction as divergent that use a value defined in \p 104 /// OuterDivLoop. Push their users on the worklist. 105 void analyzeTemporalDivergence(const Instruction &I, 106 const Loop &OuterDivLoop); 107 108 /// \brief Push all users of \p Val (in the region) to the worklist. 109 void pushUsers(const Value &I); 110 111 /// \brief Whether \p Val is divergent when read in \p ObservingBlock. 112 bool isTemporalDivergent(const BasicBlock &ObservingBlock, 113 const Value &Val) const; 114 115 /// \brief Whether \p Block is join divergent 116 /// 117 /// (see markBlockJoinDivergent). isJoinDivergent(const BasicBlock & Block)118 bool isJoinDivergent(const BasicBlock &Block) const { 119 return DivergentJoinBlocks.contains(&Block); 120 } 121 122 private: 123 const Function &F; 124 // If regionLoop != nullptr, analysis is only performed within \p RegionLoop. 125 // Otherwise, analyze the whole function 126 const Loop *RegionLoop; 127 128 const DominatorTree &DT; 129 const LoopInfo &LI; 130 131 // Recognized divergent loops 132 DenseSet<const Loop *> DivergentLoops; 133 134 // The SDA links divergent branches to divergent control-flow joins. 135 SyncDependenceAnalysis &SDA; 136 137 // Use simplified code path for LCSSA form. 138 bool IsLCSSAForm; 139 140 // Set of known-uniform values. 141 DenseSet<const Value *> UniformOverrides; 142 143 // Blocks with joining divergent control from different predecessors. 144 DenseSet<const BasicBlock *> DivergentJoinBlocks; // FIXME Deprecated 145 146 // Detected/marked divergent values. 147 DenseSet<const Value *> DivergentValues; 148 149 // Internal worklist for divergence propagation. 150 std::vector<const Instruction *> Worklist; 151 }; 152 153 class DivergenceInfo { 154 Function &F; 155 156 // If the function contains an irreducible region the divergence 157 // analysis can run indefinitely. We set ContainsIrreducible and no 158 // analysis is actually performed on the function. All values in 159 // this function are conservatively reported as divergent instead. 160 bool ContainsIrreducible; 161 std::unique_ptr<SyncDependenceAnalysis> SDA; 162 std::unique_ptr<DivergenceAnalysisImpl> DA; 163 164 public: 165 DivergenceInfo(Function &F, const DominatorTree &DT, 166 const PostDominatorTree &PDT, const LoopInfo &LI, 167 const TargetTransformInfo &TTI, bool KnownReducible); 168 169 /// Whether any divergence was detected. hasDivergence()170 bool hasDivergence() const { 171 return ContainsIrreducible || DA->hasDetectedDivergence(); 172 } 173 174 /// The GPU kernel this analysis result is for getFunction()175 const Function &getFunction() const { return F; } 176 177 /// Whether \p V is divergent at its definition. isDivergent(const Value & V)178 bool isDivergent(const Value &V) const { 179 return ContainsIrreducible || DA->isDivergent(V); 180 } 181 182 /// Whether \p U is divergent. Uses of a uniform value can be divergent. isDivergentUse(const Use & U)183 bool isDivergentUse(const Use &U) const { 184 return ContainsIrreducible || DA->isDivergentUse(U); 185 } 186 187 /// Whether \p V is uniform/non-divergent. isUniform(const Value & V)188 bool isUniform(const Value &V) const { return !isDivergent(V); } 189 190 /// Whether \p U is uniform/non-divergent. Uses of a uniform value can be 191 /// divergent. isUniformUse(const Use & U)192 bool isUniformUse(const Use &U) const { return !isDivergentUse(U); } 193 }; 194 195 /// \brief Divergence analysis frontend for GPU kernels. 196 class DivergenceAnalysis : public AnalysisInfoMixin<DivergenceAnalysis> { 197 friend AnalysisInfoMixin<DivergenceAnalysis>; 198 199 static AnalysisKey Key; 200 201 public: 202 using Result = DivergenceInfo; 203 204 /// Runs the divergence analysis on @F, a GPU kernel 205 Result run(Function &F, FunctionAnalysisManager &AM); 206 }; 207 208 /// Printer pass to dump divergence analysis results. 209 struct DivergenceAnalysisPrinterPass 210 : public PassInfoMixin<DivergenceAnalysisPrinterPass> { DivergenceAnalysisPrinterPassDivergenceAnalysisPrinterPass211 DivergenceAnalysisPrinterPass(raw_ostream &OS) : OS(OS) {} 212 213 PreservedAnalyses run(Function &F, FunctionAnalysisManager &FAM); 214 215 private: 216 raw_ostream &OS; 217 }; // class DivergenceAnalysisPrinterPass 218 219 } // namespace llvm 220 221 #endif // LLVM_ANALYSIS_DIVERGENCEANALYSIS_H 222