1 //===- llvm/Analysis/DivergenceAnalysis.h - Divergence Analysis -*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // \file 10 // The divergence analysis determines which instructions and branches are 11 // divergent given a set of divergent source instructions. 12 // 13 //===----------------------------------------------------------------------===// 14 15 #ifndef LLVM_ANALYSIS_DIVERGENCEANALYSIS_H 16 #define LLVM_ANALYSIS_DIVERGENCEANALYSIS_H 17 18 #include "llvm/ADT/DenseSet.h" 19 #include "llvm/Analysis/SyncDependenceAnalysis.h" 20 #include "llvm/IR/Function.h" 21 #include "llvm/Pass.h" 22 #include <vector> 23 24 namespace llvm { 25 class Module; 26 class Value; 27 class Instruction; 28 class Loop; 29 class raw_ostream; 30 class TargetTransformInfo; 31 32 /// \brief Generic divergence analysis for reducible CFGs. 33 /// 34 /// This analysis propagates divergence in a data-parallel context from sources 35 /// of divergence to all users. It requires reducible CFGs. All assignments 36 /// should be in SSA form. 37 class DivergenceAnalysisImpl { 38 public: 39 /// \brief This instance will analyze the whole function \p F or the loop \p 40 /// RegionLoop. 41 /// 42 /// \param RegionLoop if non-null the analysis is restricted to \p RegionLoop. 43 /// Otherwise the whole function is analyzed. 44 /// \param IsLCSSAForm whether the analysis may assume that the IR in the 45 /// region in in LCSSA form. 46 DivergenceAnalysisImpl(const Function &F, const Loop *RegionLoop, 47 const DominatorTree &DT, const LoopInfo &LI, 48 SyncDependenceAnalysis &SDA, bool IsLCSSAForm); 49 50 /// \brief The loop that defines the analyzed region (if any). getRegionLoop()51 const Loop *getRegionLoop() const { return RegionLoop; } getFunction()52 const Function &getFunction() const { return F; } 53 54 /// \brief Whether \p BB is part of the region. 55 bool inRegion(const BasicBlock &BB) const; 56 /// \brief Whether \p I is part of the region. 57 bool inRegion(const Instruction &I) const; 58 59 /// \brief Mark \p UniVal as a value that is always uniform. 60 void addUniformOverride(const Value &UniVal); 61 62 /// \brief Mark \p DivVal as a value that is always divergent. Will not do so 63 /// if `isAlwaysUniform(DivVal)`. 64 /// \returns Whether the tracked divergence state of \p DivVal changed. 65 bool markDivergent(const Value &DivVal); 66 67 /// \brief Propagate divergence to all instructions in the region. 68 /// Divergence is seeded by calls to \p markDivergent. 69 void compute(); 70 71 /// \brief Whether any value was marked or analyzed to be divergent. hasDetectedDivergence()72 bool hasDetectedDivergence() const { return !DivergentValues.empty(); } 73 74 /// \brief Whether \p Val will always return a uniform value regardless of its 75 /// operands 76 bool isAlwaysUniform(const Value &Val) const; 77 78 /// \brief Whether \p Val is divergent at its definition. 79 bool isDivergent(const Value &Val) const; 80 81 /// \brief Whether \p U is divergent. Uses of a uniform value can be 82 /// divergent. 83 bool isDivergentUse(const Use &U) const; 84 85 private: 86 /// \brief Mark \p Term as divergent and push all Instructions that become 87 /// divergent as a result on the worklist. 88 void analyzeControlDivergence(const Instruction &Term); 89 /// \brief Mark all phi nodes in \p JoinBlock as divergent and push them on 90 /// the worklist. 91 void taintAndPushPhiNodes(const BasicBlock &JoinBlock); 92 93 /// \brief Identify all Instructions that become divergent because \p DivExit 94 /// is a divergent loop exit of \p DivLoop. Mark those instructions as 95 /// divergent and push them on the worklist. 96 void propagateLoopExitDivergence(const BasicBlock &DivExit, 97 const Loop &DivLoop); 98 99 /// \brief Internal implementation function for propagateLoopExitDivergence. 100 void analyzeLoopExitDivergence(const BasicBlock &DivExit, 101 const Loop &OuterDivLoop); 102 103 /// \brief Mark all instruction as divergent that use a value defined in \p 104 /// OuterDivLoop. Push their users on the worklist. 105 void analyzeTemporalDivergence(const Instruction &I, 106 const Loop &OuterDivLoop); 107 108 /// \brief Push all users of \p Val (in the region) to the worklist. 109 void pushUsers(const Value &I); 110 111 /// \brief Whether \p Val is divergent when read in \p ObservingBlock. 112 bool isTemporalDivergent(const BasicBlock &ObservingBlock, 113 const Value &Val) const; 114 115 private: 116 const Function &F; 117 // If regionLoop != nullptr, analysis is only performed within \p RegionLoop. 118 // Otherwise, analyze the whole function 119 const Loop *RegionLoop; 120 121 const DominatorTree &DT; 122 const LoopInfo &LI; 123 124 // Recognized divergent loops 125 DenseSet<const Loop *> DivergentLoops; 126 127 // The SDA links divergent branches to divergent control-flow joins. 128 SyncDependenceAnalysis &SDA; 129 130 // Use simplified code path for LCSSA form. 131 bool IsLCSSAForm; 132 133 // Set of known-uniform values. 134 DenseSet<const Value *> UniformOverrides; 135 136 // Detected/marked divergent values. 137 DenseSet<const Value *> DivergentValues; 138 139 // Internal worklist for divergence propagation. 140 std::vector<const Instruction *> Worklist; 141 }; 142 143 class DivergenceInfo { 144 Function &F; 145 146 // If the function contains an irreducible region the divergence 147 // analysis can run indefinitely. We set ContainsIrreducible and no 148 // analysis is actually performed on the function. All values in 149 // this function are conservatively reported as divergent instead. 150 bool ContainsIrreducible; 151 std::unique_ptr<SyncDependenceAnalysis> SDA; 152 std::unique_ptr<DivergenceAnalysisImpl> DA; 153 154 public: 155 DivergenceInfo(Function &F, const DominatorTree &DT, 156 const PostDominatorTree &PDT, const LoopInfo &LI, 157 const TargetTransformInfo &TTI, bool KnownReducible); 158 159 /// Whether any divergence was detected. hasDivergence()160 bool hasDivergence() const { 161 return ContainsIrreducible || DA->hasDetectedDivergence(); 162 } 163 164 /// The GPU kernel this analysis result is for getFunction()165 const Function &getFunction() const { return F; } 166 167 /// Whether \p V is divergent at its definition. isDivergent(const Value & V)168 bool isDivergent(const Value &V) const { 169 return ContainsIrreducible || DA->isDivergent(V); 170 } 171 172 /// Whether \p U is divergent. Uses of a uniform value can be divergent. isDivergentUse(const Use & U)173 bool isDivergentUse(const Use &U) const { 174 return ContainsIrreducible || DA->isDivergentUse(U); 175 } 176 177 /// Whether \p V is uniform/non-divergent. isUniform(const Value & V)178 bool isUniform(const Value &V) const { return !isDivergent(V); } 179 180 /// Whether \p U is uniform/non-divergent. Uses of a uniform value can be 181 /// divergent. isUniformUse(const Use & U)182 bool isUniformUse(const Use &U) const { return !isDivergentUse(U); } 183 }; 184 185 /// \brief Divergence analysis frontend for GPU kernels. 186 class DivergenceAnalysis : public AnalysisInfoMixin<DivergenceAnalysis> { 187 friend AnalysisInfoMixin<DivergenceAnalysis>; 188 189 static AnalysisKey Key; 190 191 public: 192 using Result = DivergenceInfo; 193 194 /// Runs the divergence analysis on @F, a GPU kernel 195 Result run(Function &F, FunctionAnalysisManager &AM); 196 }; 197 198 /// Printer pass to dump divergence analysis results. 199 struct DivergenceAnalysisPrinterPass 200 : public PassInfoMixin<DivergenceAnalysisPrinterPass> { DivergenceAnalysisPrinterPassDivergenceAnalysisPrinterPass201 DivergenceAnalysisPrinterPass(raw_ostream &OS) : OS(OS) {} 202 203 PreservedAnalyses run(Function &F, FunctionAnalysisManager &FAM); 204 205 private: 206 raw_ostream &OS; 207 }; // class DivergenceAnalysisPrinterPass 208 209 } // namespace llvm 210 211 #endif // LLVM_ANALYSIS_DIVERGENCEANALYSIS_H 212