1 //===- GenericUniformityImpl.h -----------------------*- C++ -*------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This template implementation resides in a separate file so that it
10 // does not get injected into every .cpp file that includes the
11 // generic header.
12 //
13 // DO NOT INCLUDE THIS FILE WHEN MERELY USING UNIFORMITYINFO.
14 //
15 // This file should only be included by files that implement a
// specialization of the relevant templates. Currently these are:
17 // - UniformityAnalysis.cpp
18 //
19 // Note: The DEBUG_TYPE macro should be defined before using this
20 // file so that any use of LLVM_DEBUG is associated with the
21 // including file rather than this file.
22 //
23 //===----------------------------------------------------------------------===//
24 ///
25 /// \file
26 /// \brief Implementation of uniformity analysis.
27 ///
28 /// The algorithm is a fixed point iteration that starts with the assumption
29 /// that all control flow and all values are uniform. Starting from sources of
30 /// divergence (whose discovery must be implemented by a CFG- or even
31 /// target-specific derived class), divergence of values is propagated from
32 /// definition to uses in a straight-forward way. The main complexity lies in
33 /// the propagation of the impact of divergent control flow on the divergence of
34 /// values (sync dependencies).
35 ///
36 /// NOTE: In general, no interface exists for a transform to update
37 /// (Machine)UniformityInfo. Additionally, (Machine)CycleAnalysis is a
38 /// transitive dependence, but it also does not provide an interface for
39 /// updating itself. Given that, transforms should not preserve uniformity in
40 /// their getAnalysisUsage() callback.
41 ///
42 //===----------------------------------------------------------------------===//
43 
44 #ifndef LLVM_ADT_GENERICUNIFORMITYIMPL_H
45 #define LLVM_ADT_GENERICUNIFORMITYIMPL_H
46 
47 #include "llvm/ADT/GenericUniformityInfo.h"
48 
49 #include "llvm/ADT/SmallPtrSet.h"
50 #include "llvm/ADT/SparseBitVector.h"
51 #include "llvm/ADT/StringExtras.h"
52 #include "llvm/Support/raw_ostream.h"
53 
54 #include <set>
55 
56 #define DEBUG_TYPE "uniformity"
57 
58 namespace llvm {
59 
/// Run std::unique over the whole of \p R, resolving begin/end through
/// ADL, and return the iterator past the last retained element.
template <typename Range> auto unique(Range &&R) {
  auto First = adl_begin(R);
  auto Last = adl_end(R);
  return std::unique(First, Last);
}
63 
64 /// Construct a specially modified post-order traversal of cycles.
65 ///
/// The ModifiedPO is constructed using a virtually modified CFG as follows:
67 ///
/// 1. The successors of pre-entry nodes (predecessors of a cycle
69 ///    entry that are outside the cycle) are replaced by the
70 ///    successors of the successors of the header.
71 /// 2. Successors of the cycle header are replaced by the exit blocks
72 ///    of the cycle.
73 ///
74 /// Effectively, we produce a depth-first numbering with the following
75 /// properties:
76 ///
77 /// 1. Nodes after a cycle are numbered earlier than the cycle header.
78 /// 2. The header is numbered earlier than the nodes in the cycle.
79 /// 3. The numbering of the nodes within the cycle forms an interval
80 ///    starting with the header.
81 ///
82 /// Effectively, the virtual modification arranges the nodes in a
83 /// cycle as a DAG with the header as the sole leaf, and successors of
84 /// the header as the roots. A reverse traversal of this numbering has
85 /// the following invariant on the unmodified original CFG:
86 ///
87 ///    Each node is visited after all its predecessors, except if that
88 ///    predecessor is the cycle header.
89 ///
90 template <typename ContextT> class ModifiedPostOrder {
91 public:
92   using BlockT = typename ContextT::BlockT;
93   using FunctionT = typename ContextT::FunctionT;
94   using DominatorTreeT = typename ContextT::DominatorTreeT;
95 
96   using CycleInfoT = GenericCycleInfo<ContextT>;
97   using CycleT = typename CycleInfoT::CycleT;
98   using const_iterator = typename std::vector<BlockT *>::const_iterator;
99 
ModifiedPostOrder(const ContextT & C)100   ModifiedPostOrder(const ContextT &C) : Context(C) {}
101 
empty()102   bool empty() const { return m_order.empty(); }
size()103   size_t size() const { return m_order.size(); }
104 
clear()105   void clear() { m_order.clear(); }
106   void compute(const CycleInfoT &CI);
107 
count(BlockT * BB)108   unsigned count(BlockT *BB) const { return POIndex.count(BB); }
109   const BlockT *operator[](size_t idx) const { return m_order[idx]; }
110 
111   void appendBlock(const BlockT &BB, bool isReducibleCycleHeader = false) {
112     POIndex[&BB] = m_order.size();
113     m_order.push_back(&BB);
114     LLVM_DEBUG(dbgs() << "ModifiedPO(" << POIndex[&BB]
115                       << "): " << Context.print(&BB) << "\n");
116     if (isReducibleCycleHeader)
117       ReducibleCycleHeaders.insert(&BB);
118   }
119 
getIndex(const BlockT * BB)120   unsigned getIndex(const BlockT *BB) const {
121     assert(POIndex.count(BB));
122     return POIndex.lookup(BB);
123   }
124 
isReducibleCycleHeader(const BlockT * BB)125   bool isReducibleCycleHeader(const BlockT *BB) const {
126     return ReducibleCycleHeaders.contains(BB);
127   }
128 
129 private:
130   SmallVector<const BlockT *> m_order;
131   DenseMap<const BlockT *, unsigned> POIndex;
132   SmallPtrSet<const BlockT *, 32> ReducibleCycleHeaders;
133   const ContextT &Context;
134 
135   void computeCyclePO(const CycleInfoT &CI, const CycleT *Cycle,
136                       SmallPtrSetImpl<const BlockT *> &Finalized);
137 
138   void computeStackPO(SmallVectorImpl<const BlockT *> &Stack,
139                       const CycleInfoT &CI, const CycleT *Cycle,
140                       SmallPtrSetImpl<const BlockT *> &Finalized);
141 };
142 
143 template <typename> class DivergencePropagator;
144 
145 /// \class GenericSyncDependenceAnalysis
146 ///
147 /// \brief Locate join blocks for disjoint paths starting at a divergent branch.
148 ///
149 /// An analysis per divergent branch that returns the set of basic
150 /// blocks whose phi nodes become divergent due to divergent control.
151 /// These are the blocks that are reachable by two disjoint paths from
152 /// the branch, or cycle exits reachable along a path that is disjoint
153 /// from a path to the cycle latch.
154 
// (End of doxygen comment. The remainder is a non-doxygen implementation note.)
156 //
157 // Originally implemented in SyncDependenceAnalysis.cpp for DivergenceAnalysis.
158 //
159 // The SyncDependenceAnalysis is used in the UniformityAnalysis to model
160 // control-induced divergence in phi nodes.
161 //
162 // -- Reference --
163 // The algorithm is an extension of Section 5 of
164 //
165 //   An abstract interpretation for SPMD divergence
166 //       on reducible control flow graphs.
167 //   Julian Rosemann, Simon Moll and Sebastian Hack
168 //   POPL '21
169 //
170 //
171 // -- Sync dependence --
172 // Sync dependence characterizes the control flow aspect of the
173 // propagation of branch divergence. For example,
174 //
175 //   %cond = icmp slt i32 %tid, 10
176 //   br i1 %cond, label %then, label %else
177 // then:
178 //   br label %merge
179 // else:
180 //   br label %merge
181 // merge:
182 //   %a = phi i32 [ 0, %then ], [ 1, %else ]
183 //
184 // Suppose %tid holds the thread ID. Although %a is not data dependent on %tid
185 // because %tid is not on its use-def chains, %a is sync dependent on %tid
186 // because the branch "br i1 %cond" depends on %tid and affects which value %a
187 // is assigned to.
188 //
189 //
190 // -- Reduction to SSA construction --
191 // There are two disjoint paths from A to X, if a certain variant of SSA
192 // construction places a phi node in X under the following set-up scheme.
193 //
194 // This variant of SSA construction ignores incoming undef values.
195 // That is paths from the entry without a definition do not result in
196 // phi nodes.
197 //
198 //       entry
199 //     /      \
200 //    A        \
201 //  /   \       Y
202 // B     C     /
203 //  \   /  \  /
204 //    D     E
205 //     \   /
206 //       F
207 //
208 // Assume that A contains a divergent branch. We are interested
209 // in the set of all blocks where each block is reachable from A
210 // via two disjoint paths. This would be the set {D, F} in this
211 // case.
212 // To generally reduce this query to SSA construction we introduce
213 // a virtual variable x and assign to x different values in each
214 // successor block of A.
215 //
216 //           entry
217 //         /      \
218 //        A        \
219 //      /   \       Y
220 // x = 0   x = 1   /
221 //      \  /   \  /
222 //        D     E
223 //         \   /
224 //           F
225 //
226 // Our flavor of SSA construction for x will construct the following
227 //
228 //            entry
229 //          /      \
230 //         A        \
231 //       /   \       Y
232 // x0 = 0   x1 = 1  /
233 //       \   /   \ /
234 //     x2 = phi   E
235 //         \     /
236 //         x3 = phi
237 //
238 // The blocks D and F contain phi nodes and are thus each reachable
// by two disjoint paths from A.
240 //
241 // -- Remarks --
242 // * In case of cycle exits we need to check for temporal divergence.
243 //   To this end, we check whether the definition of x differs between the
244 //   cycle exit and the cycle header (_after_ SSA construction).
245 //
246 // * In the presence of irreducible control flow, the fixed point is
247 //   reached only after multiple iterations. This is because labels
248 //   reaching the header of a cycle must be repropagated through the
249 //   cycle. This is true even in a reducible cycle, since the labels
250 //   may have been produced by a nested irreducible cycle.
251 //
252 // * Note that SyncDependenceAnalysis is not concerned with the points
//   of convergence in an irreducible cycle. Its only purpose is to
254 //   identify join blocks. The "diverged entry" criterion is
255 //   separately applied on join blocks to determine if an entire
256 //   irreducible cycle is assumed to be divergent.
257 //
258 // * Relevant related work:
259 //     A simple algorithm for global data flow analysis problems.
260 //     Matthew S. Hecht and Jeffrey D. Ullman.
261 //     SIAM Journal on Computing, 4(4):519–532, December 1975.
262 //
template <typename ContextT> class GenericSyncDependenceAnalysis {
public:
  using BlockT = typename ContextT::BlockT;
  using DominatorTreeT = typename ContextT::DominatorTreeT;
  using FunctionT = typename ContextT::FunctionT;
  using ValueRefT = typename ContextT::ValueRefT;
  using InstructionT = typename ContextT::InstructionT;

  using CycleInfoT = GenericCycleInfo<ContextT>;
  using CycleT = typename CycleInfoT::CycleT;

  using ConstBlockSet = SmallPtrSet<const BlockT *, 4>;
  using ModifiedPO = ModifiedPostOrder<ContextT>;

  // Labels implement the virtual-variable SSA construction described in
  // the file comment above, where X is the divergent branch block:
  // * if BlockLabels[B] == C then C is the dominating definition at
  //   block B
  // * if BlockLabels[B] == nullptr then we haven't seen B yet
  // * if BlockLabels[B] == B then:
  //   - B is a join point of disjoint paths from X, or,
  //   - B is an immediate successor of X (initial value), or,
  //   - B is X
  using BlockLabelMap = DenseMap<const BlockT *, const BlockT *>;

  /// Information discovered by the sync dependence analysis for each
  /// divergent branch.
  struct DivergenceDescriptor {
    // Join points of diverged paths.
    ConstBlockSet JoinDivBlocks;
    // Divergent cycle exits
    ConstBlockSet CycleDivBlocks;
    // Labels assigned to blocks on diverged paths.
    BlockLabelMap BlockLabels;
  };

  using DivergencePropagatorT = DivergencePropagator<ContextT>;

  GenericSyncDependenceAnalysis(const ContextT &Context,
                                const DominatorTreeT &DT, const CycleInfoT &CI);

  /// \brief Computes divergent join points and cycle exits caused by branch
  /// divergence in \p Term.
  ///
  /// This returns a pair of sets:
  /// * The set of blocks which are reachable by disjoint paths from
  ///   \p Term.
  /// * The set also contains cycle exits if there are two disjoint paths:
  ///   one from \p Term to the cycle exit and another from \p Term to
  ///   the cycle header.
  ///
  /// Results are memoized per block; a branch with at most one
  /// successor trivially causes no joins and shares a single static
  /// empty descriptor.
  const DivergenceDescriptor &getJoinBlocks(const BlockT *DivTermBlock);

private:
  // Shared empty result for terminators with at most one successor.
  static DivergenceDescriptor EmptyDivergenceDesc;

  // Modified cycle post-order, computed once in the constructor and
  // reused by every DivergencePropagator run.
  ModifiedPO CyclePO;

  const DominatorTreeT &DT;
  const CycleInfoT &CI;

  // Per-branch cache of computed descriptors (see getJoinBlocks()).
  DenseMap<const BlockT *, std::unique_ptr<DivergenceDescriptor>>
      CachedControlDivDescs;
};
324 
325 /// \brief Analysis that identifies uniform values in a data-parallel
326 /// execution.
327 ///
328 /// This analysis propagates divergence in a data-parallel context
329 /// from sources of divergence to all users. It can be instantiated
330 /// for an IR that provides a suitable SSAContext.
template <typename ContextT> class GenericUniformityAnalysisImpl {
public:
  using BlockT = typename ContextT::BlockT;
  using FunctionT = typename ContextT::FunctionT;
  using ValueRefT = typename ContextT::ValueRefT;
  using ConstValueRefT = typename ContextT::ConstValueRefT;
  using UseT = typename ContextT::UseT;
  using InstructionT = typename ContextT::InstructionT;
  using DominatorTreeT = typename ContextT::DominatorTreeT;

  using CycleInfoT = GenericCycleInfo<ContextT>;
  using CycleT = typename CycleInfoT::CycleT;

  using SyncDependenceAnalysisT = GenericSyncDependenceAnalysis<ContextT>;
  using DivergenceDescriptorT =
      typename SyncDependenceAnalysisT::DivergenceDescriptor;
  using BlockLabelMapT = typename SyncDependenceAnalysisT::BlockLabelMap;

  GenericUniformityAnalysisImpl(const DominatorTreeT &DT, const CycleInfoT &CI,
                                const TargetTransformInfo *TTI)
      : Context(CI.getSSAContext()), F(*Context.getFunction()), CI(CI),
        TTI(TTI), DT(DT), SDA(Context, DT, CI) {}

  /// \brief Seed the analysis (defined in the specializing .cpp file).
  void initialize();

  const FunctionT &getFunction() const { return F; }

  /// \brief Mark \p Instr as a value that is always uniform.
  void addUniformOverride(const InstructionT &Instr);

  /// \brief Examine \p I for divergent outputs and add to the worklist.
  void markDivergent(const InstructionT &I);

  /// \brief Mark \p DivVal as a divergent value.
  /// \returns Whether the tracked divergence state of \p DivVal changed.
  bool markDivergent(ConstValueRefT DivVal);

  /// \brief Mark outputs of \p Instr as divergent.
  /// \returns Whether the tracked divergence state of any output has changed.
  bool markDefsDivergent(const InstructionT &Instr);

  /// \brief Propagate divergence to all instructions in the region.
  /// Divergence is seeded by calls to \p markDivergent.
  void compute();

  /// \brief Whether any value was marked or analyzed to be divergent.
  bool hasDivergence() const { return !DivergentValues.empty(); }

  /// \brief Whether \p Val will always return a uniform value regardless of its
  /// operands
  bool isAlwaysUniform(const InstructionT &Instr) const;

  bool hasDivergentDefs(const InstructionT &I) const;

  bool isDivergent(const InstructionT &I) const {
    // Terminators produce no values; their divergence is tracked
    // per-block in DivergentTermBlocks instead.
    if (I.isTerminator()) {
      return DivergentTermBlocks.contains(I.getParent());
    }
    return hasDivergentDefs(I);
  };

  /// \brief Whether \p Val is divergent at its definition.
  bool isDivergent(ConstValueRefT V) const { return DivergentValues.count(V); }

  bool isDivergentUse(const UseT &U) const;

  bool hasDivergentTerminator(const BlockT &B) const {
    return DivergentTermBlocks.contains(&B);
  }

  void print(raw_ostream &out) const;

protected:
  /// \brief Value/block pair representing a single phi input.
  struct PhiInput {
    ConstValueRefT value;
    BlockT *predBlock;

    PhiInput(ConstValueRefT value, BlockT *predBlock)
        : value(value), predBlock(predBlock) {}
  };

  const ContextT &Context;
  const FunctionT &F;
  const CycleInfoT &CI;
  const TargetTransformInfo *TTI = nullptr;

  // Detected/marked divergent values.
  // NOTE(review): std::set rather than a hashed set — presumably for
  // deterministic iteration order; confirm before changing.
  std::set<ConstValueRefT> DivergentValues;
  // Blocks whose terminator has been marked divergent.
  SmallPtrSet<const BlockT *, 32> DivergentTermBlocks;

  // Internal worklist for divergence propagation.
  std::vector<const InstructionT *> Worklist;

  /// \brief Mark \p Term as divergent and push all Instructions that become
  /// divergent as a result on the worklist.
  void analyzeControlDivergence(const InstructionT &Term);

private:
  const DominatorTreeT &DT;

  // Recognized cycles with divergent exits.
  SmallPtrSet<const CycleT *, 16> DivergentExitCycles;

  // Cycles assumed to be divergent.
  //
  // We don't use a set here because every insertion needs an explicit
  // traversal of all existing members.
  SmallVector<const CycleT *> AssumedDivergent;

  // The SDA links divergent branches to divergent control-flow joins.
  SyncDependenceAnalysisT SDA;

  // Set of known-uniform values.
  SmallPtrSet<const InstructionT *, 32> UniformOverrides;

  /// \brief Mark all nodes in \p JoinBlock as divergent and push them on
  /// the worklist.
  void taintAndPushAllDefs(const BlockT &JoinBlock);

  /// \brief Mark all phi nodes in \p JoinBlock as divergent and push them on
  /// the worklist.
  void taintAndPushPhiNodes(const BlockT &JoinBlock);

  /// \brief Identify all Instructions that become divergent because \p DivExit
  /// is a divergent cycle exit of \p DivCycle. Mark those instructions as
  /// divergent and push them on the worklist.
  void propagateCycleExitDivergence(const BlockT &DivExit,
                                    const CycleT &DivCycle);

  /// Mark as divergent all external uses of values defined in \p DefCycle.
  void analyzeCycleExitDivergence(const CycleT &DefCycle);

  /// \brief Mark as divergent all uses of \p I that are outside \p DefCycle.
  void propagateTemporalDivergence(const InstructionT &I,
                                   const CycleT &DefCycle);

  /// \brief Push all users of \p Val (in the region) to the worklist.
  void pushUsers(const InstructionT &I);
  void pushUsers(ConstValueRefT V);

  bool usesValueFromCycle(const InstructionT &I, const CycleT &DefCycle) const;

  /// \brief Whether \p Def is divergent when read in \p ObservingBlock.
  bool isTemporalDivergent(const BlockT &ObservingBlock,
                           const InstructionT &Def) const;
};
478 
// Out-of-line deleter for the analysis implementation.
// NOTE(review): presumably exists so that GenericUniformityInfo can hold
// the impl through a unique_ptr with only a forward declaration of ImplT
// — the delete happens here, where ImplT is complete. Confirm against
// GenericUniformityInfo.h.
template <typename ImplT>
void GenericUniformityAnalysisImplDeleter<ImplT>::operator()(ImplT *Impl) {
  delete Impl;
}
483 
484 /// Compute divergence starting with a divergent branch.
template <typename ContextT> class DivergencePropagator {
public:
  using BlockT = typename ContextT::BlockT;
  using DominatorTreeT = typename ContextT::DominatorTreeT;
  using FunctionT = typename ContextT::FunctionT;
  using ValueRefT = typename ContextT::ValueRefT;

  using CycleInfoT = GenericCycleInfo<ContextT>;
  using CycleT = typename CycleInfoT::CycleT;

  using ModifiedPO = ModifiedPostOrder<ContextT>;
  using SyncDependenceAnalysisT = GenericSyncDependenceAnalysis<ContextT>;
  using DivergenceDescriptorT =
      typename SyncDependenceAnalysisT::DivergenceDescriptor;
  using BlockLabelMapT = typename SyncDependenceAnalysisT::BlockLabelMap;

  const ModifiedPO &CyclePOT;
  const DominatorTreeT &DT;
  const CycleInfoT &CI;
  const BlockT &DivTermBlock;
  const ContextT &Context;

  // Track blocks that receive a new label. Every time we relabel a
  // cycle header, we make another pass over the modified post-order in
  // order to propagate the header label. The bit vector also allows
  // us to skip labels that have not changed.
  SparseBitVector<> FreshLabels;

  // divergent join and cycle exit descriptor.
  std::unique_ptr<DivergenceDescriptorT> DivDesc;
  // Alias into DivDesc->BlockLabels for convenience.
  BlockLabelMapT &BlockLabels;

  DivergencePropagator(const ModifiedPO &CyclePOT, const DominatorTreeT &DT,
                       const CycleInfoT &CI, const BlockT &DivTermBlock)
      : CyclePOT(CyclePOT), DT(DT), CI(CI), DivTermBlock(DivTermBlock),
        Context(CI.getSSAContext()), DivDesc(new DivergenceDescriptorT),
        BlockLabels(DivDesc->BlockLabels) {}

  /// Debugging helper: dump every block's current label in reverse
  /// modified post-order (the direction in which labels propagate).
  void printDefs(raw_ostream &Out) {
    Out << "Propagator::BlockLabels {\n";
    for (int BlockIdx = (int)CyclePOT.size() - 1; BlockIdx >= 0; --BlockIdx) {
      const auto *Block = CyclePOT[BlockIdx];
      const auto *Label = BlockLabels[Block];
      Out << Context.print(Block) << "(" << BlockIdx << ") : ";
      if (!Label) {
        Out << "<null>\n";
      } else {
        Out << Context.print(Label) << "\n";
      }
    }
    Out << "}\n";
  }

  // Push a definition (\p PushedLabel) to \p SuccBlock and return whether this
  // causes a divergent join.
  bool computeJoin(const BlockT &SuccBlock, const BlockT &PushedLabel) {
    const auto *OldLabel = BlockLabels[&SuccBlock];

    LLVM_DEBUG(dbgs() << "labeling " << Context.print(&SuccBlock) << ":\n"
                      << "\tpushed label: " << Context.print(&PushedLabel)
                      << "\n"
                      << "\told label: " << Context.print(OldLabel) << "\n");

    // Early exit if there is no change in the label.
    if (OldLabel == &PushedLabel)
      return false;

    // A block already labeled as itself is a known join; relabeling it
    // again is a no-op, so only non-self labels count as fresh.
    if (OldLabel != &SuccBlock) {
      auto SuccIdx = CyclePOT.getIndex(&SuccBlock);
      // Assigning a new label, mark this in FreshLabels.
      LLVM_DEBUG(dbgs() << "\tfresh label: " << SuccIdx << "\n");
      FreshLabels.set(SuccIdx);
    }

    // This is not a join if the succ was previously unlabeled.
    if (!OldLabel) {
      LLVM_DEBUG(dbgs() << "\tnew label: " << Context.print(&PushedLabel)
                        << "\n");
      BlockLabels[&SuccBlock] = &PushedLabel;
      return false;
    }

    // This is a new join. Label the join block as itself, and not as
    // the pushed label.
    LLVM_DEBUG(dbgs() << "\tnew label: " << Context.print(&SuccBlock) << "\n");
    BlockLabels[&SuccBlock] = &SuccBlock;

    return true;
  }

  // visiting a virtual cycle exit edge from the cycle header --> temporal
  // divergence on join
  bool visitCycleExitEdge(const BlockT &ExitBlock, const BlockT &Label) {
    if (!computeJoin(ExitBlock, Label))
      return false;

    // Identified a divergent cycle exit
    DivDesc->CycleDivBlocks.insert(&ExitBlock);
    LLVM_DEBUG(dbgs() << "\tDivergent cycle exit: " << Context.print(&ExitBlock)
                      << "\n");
    return true;
  }

  // process \p SuccBlock with reaching definition \p Label
  bool visitEdge(const BlockT &SuccBlock, const BlockT &Label) {
    if (!computeJoin(SuccBlock, Label))
      return false;

    // Divergent, disjoint paths join.
    DivDesc->JoinDivBlocks.insert(&SuccBlock);
    LLVM_DEBUG(dbgs() << "\tDivergent join: " << Context.print(&SuccBlock)
                      << "\n");
    return true;
  }

  // Fixed-point label propagation over the modified post-order,
  // starting from the successors of the divergent branch. Blocks that
  // end up labeled as themselves are joins of disjoint paths.
  std::unique_ptr<DivergenceDescriptorT> computeJoinPoints() {
    assert(DivDesc);

    LLVM_DEBUG(dbgs() << "SDA:computeJoinPoints: "
                      << Context.print(&DivTermBlock) << "\n");

    // Early stopping criterion: FloorIdx is the lowest index any label
    // has been pushed to; fresh labels below it cannot exist.
    int FloorIdx = CyclePOT.size() - 1;
    const BlockT *FloorLabel = nullptr;
    int DivTermIdx = CyclePOT.getIndex(&DivTermBlock);

    // Bootstrap with branch targets
    auto const *DivTermCycle = CI.getCycle(&DivTermBlock);
    for (const auto *SuccBlock : successors(&DivTermBlock)) {
      if (DivTermCycle && !DivTermCycle->contains(SuccBlock)) {
        // If DivTerm exits the cycle immediately, computeJoin() might
        // not reach SuccBlock with a different label. We need to
        // check for this exit now.
        DivDesc->CycleDivBlocks.insert(SuccBlock);
        LLVM_DEBUG(dbgs() << "\tImmediate divergent cycle exit: "
                          << Context.print(SuccBlock) << "\n");
      }
      auto SuccIdx = CyclePOT.getIndex(SuccBlock);
      visitEdge(*SuccBlock, *SuccBlock);
      FloorIdx = std::min<int>(FloorIdx, SuccIdx);
    }

    while (true) {
      auto BlockIdx = FreshLabels.find_last();
      if (BlockIdx == -1 || BlockIdx < FloorIdx)
        break;

      LLVM_DEBUG(dbgs() << "Current labels:\n"; printDefs(dbgs()));

      FreshLabels.reset(BlockIdx);
      if (BlockIdx == DivTermIdx) {
        LLVM_DEBUG(dbgs() << "Skipping DivTermBlock\n");
        continue;
      }

      const auto *Block = CyclePOT[BlockIdx];
      LLVM_DEBUG(dbgs() << "visiting " << Context.print(Block) << " at index "
                        << BlockIdx << "\n");

      const auto *Label = BlockLabels[Block];
      assert(Label);

      bool CausedJoin = false;
      int LoweredFloorIdx = FloorIdx;

      // If the current block is the header of a reducible cycle that
      // contains the divergent branch, then the label should be
      // propagated to the cycle exits. Such a header is the "last
      // possible join" of any disjoint paths within this cycle. This
      // prevents detection of spurious joins at the entries of any
      // irreducible child cycles.
      //
      // This conclusion about the header is true for any choice of DFS:
      //
      //   If some DFS has a reducible cycle C with header H, then for
      //   any other DFS, H is the header of a cycle C' that is a
      //   superset of C. For a divergent branch inside the subgraph
      //   C, any join node inside C is either H, or some node
      //   encountered without passing through H.
      //
      auto getReducibleParent = [&](const BlockT *Block) -> const CycleT * {
        if (!CyclePOT.isReducibleCycleHeader(Block))
          return nullptr;
        const auto *BlockCycle = CI.getCycle(Block);
        if (BlockCycle->contains(&DivTermBlock))
          return BlockCycle;
        return nullptr;
      };

      if (const auto *BlockCycle = getReducibleParent(Block)) {
        // Virtual edges: header --> each cycle exit (see file comment).
        SmallVector<BlockT *, 4> BlockCycleExits;
        BlockCycle->getExitBlocks(BlockCycleExits);
        for (auto *BlockCycleExit : BlockCycleExits) {
          CausedJoin |= visitCycleExitEdge(*BlockCycleExit, *Label);
          LoweredFloorIdx =
              std::min<int>(LoweredFloorIdx, CyclePOT.getIndex(BlockCycleExit));
        }
      } else {
        for (const auto *SuccBlock : successors(Block)) {
          CausedJoin |= visitEdge(*SuccBlock, *Label);
          LoweredFloorIdx =
              std::min<int>(LoweredFloorIdx, CyclePOT.getIndex(SuccBlock));
        }
      }

      // Floor update
      if (CausedJoin) {
        // 1. Different labels pushed to successors
        FloorIdx = LoweredFloorIdx;
      } else if (FloorLabel != Label) {
        // 2. No join caused BUT we pushed a label that is different than the
        // last pushed label
        FloorIdx = LoweredFloorIdx;
        FloorLabel = Label;
      }
    }

    LLVM_DEBUG(dbgs() << "Final labeling:\n"; printDefs(dbgs()));

    // Check every cycle containing DivTermBlock for exit divergence.
    // A cycle has exit divergence if the label of an exit block does
    // not match the label of its header.
    for (const auto *Cycle = CI.getCycle(&DivTermBlock); Cycle;
         Cycle = Cycle->getParentCycle()) {
      if (Cycle->isReducible()) {
        // The exit divergence of a reducible cycle is recorded while
        // propagating labels.
        continue;
      }
      SmallVector<BlockT *> Exits;
      Cycle->getExitBlocks(Exits);
      auto *Header = Cycle->getHeader();
      auto *HeaderLabel = BlockLabels[Header];
      for (const auto *Exit : Exits) {
        if (BlockLabels[Exit] != HeaderLabel) {
          // Identified a divergent cycle exit
          DivDesc->CycleDivBlocks.insert(Exit);
          LLVM_DEBUG(dbgs() << "\tDivergent cycle exit: " << Context.print(Exit)
                            << "\n");
        }
      }
    }

    // Hand the descriptor back to the caller; this propagator is done.
    return std::move(DivDesc);
  }
};
731 
// Definition of the static empty descriptor that getJoinBlocks() returns
// for trivially uniform terminators (at most one successor).
template <typename ContextT>
typename llvm::GenericSyncDependenceAnalysis<ContextT>::DivergenceDescriptor
    llvm::GenericSyncDependenceAnalysis<ContextT>::EmptyDivergenceDesc;
735 
// The modified cycle post-order is computed once per function here and
// reused by every DivergencePropagator created in getJoinBlocks().
template <typename ContextT>
llvm::GenericSyncDependenceAnalysis<ContextT>::GenericSyncDependenceAnalysis(
    const ContextT &Context, const DominatorTreeT &DT, const CycleInfoT &CI)
    : CyclePO(Context), DT(DT), CI(CI) {
  CyclePO.compute(CI);
}
742 
/// Compute (or fetch from cache) the divergence descriptor for the
/// divergent terminator in \p DivTermBlock: the join blocks of its
/// disjoint paths and any divergent cycle exits it causes.
template <typename ContextT>
auto llvm::GenericSyncDependenceAnalysis<ContextT>::getJoinBlocks(
    const BlockT *DivTermBlock) -> const DivergenceDescriptor & {
  // trivial case: a terminator with at most one successor cannot create
  // disjoint paths, hence no joins and no cycle divergence.
  if (succ_size(DivTermBlock) <= 1) {
    return EmptyDivergenceDesc;
  }

  // already available in cache?
  auto ItCached = CachedControlDivDescs.find(DivTermBlock);
  if (ItCached != CachedControlDivDescs.end())
    return *ItCached->second;

  // compute all join points
  DivergencePropagatorT Propagator(CyclePO, DT, CI, *DivTermBlock);
  auto DivDesc = Propagator.computeJoinPoints();

  // Debug-only pretty printer for a set of blocks.
  auto printBlockSet = [&](ConstBlockSet &Blocks) {
    return Printable([&](raw_ostream &Out) {
      Out << "[";
      ListSeparator LS;
      for (const auto *BB : Blocks) {
        Out << LS << CI.getSSAContext().print(BB);
      }
      Out << "]\n";
    });
  };

  LLVM_DEBUG(
      dbgs() << "\nResult (" << CI.getSSAContext().print(DivTermBlock)
             << "):\n  JoinDivBlocks: " << printBlockSet(DivDesc->JoinDivBlocks)
             << "  CycleDivBlocks: " << printBlockSet(DivDesc->CycleDivBlocks)
             << "\n");
  // Suppress the unused-variable warning when LLVM_DEBUG compiles away.
  (void)printBlockSet;

  // Cache the result; the descriptor is owned by the cache from here on.
  auto ItInserted =
      CachedControlDivDescs.try_emplace(DivTermBlock, std::move(DivDesc));
  assert(ItInserted.second);
  return *ItInserted.first->second;
}
783 
784 template <typename ContextT>
markDivergent(const InstructionT & I)785 void GenericUniformityAnalysisImpl<ContextT>::markDivergent(
786     const InstructionT &I) {
787   if (isAlwaysUniform(I))
788     return;
789   bool Marked = false;
790   if (I.isTerminator()) {
791     Marked = DivergentTermBlocks.insert(I.getParent()).second;
792     if (Marked) {
793       LLVM_DEBUG(dbgs() << "marked divergent term block: "
794                         << Context.print(I.getParent()) << "\n");
795     }
796   } else {
797     Marked = markDefsDivergent(I);
798   }
799 
800   if (Marked)
801     Worklist.push_back(&I);
802 }
803 
804 template <typename ContextT>
markDivergent(ConstValueRefT Val)805 bool GenericUniformityAnalysisImpl<ContextT>::markDivergent(
806     ConstValueRefT Val) {
807   if (DivergentValues.insert(Val).second) {
808     LLVM_DEBUG(dbgs() << "marked divergent: " << Context.print(Val) << "\n");
809     return true;
810   }
811   return false;
812 }
813 
814 template <typename ContextT>
addUniformOverride(const InstructionT & Instr)815 void GenericUniformityAnalysisImpl<ContextT>::addUniformOverride(
816     const InstructionT &Instr) {
817   UniformOverrides.insert(&Instr);
818 }
819 
820 // Mark as divergent all external uses of values defined in \p DefCycle.
821 //
822 // A value V defined by a block B inside \p DefCycle may be used outside the
823 // cycle only if the use is a PHI in some exit block, or B dominates some exit
824 // block. Thus, we check uses as follows:
825 //
826 // - Check all PHIs in all exit blocks for inputs defined inside \p DefCycle.
827 // - For every block B inside \p DefCycle that dominates at least one exit
828 //   block, check all uses outside \p DefCycle.
829 //
830 // FIXME: This function does not distinguish between divergent and uniform
831 // exits. For each divergent exit, only the values that are live at that exit
832 // need to be propagated as divergent at their use outside the cycle.
833 template <typename ContextT>
analyzeCycleExitDivergence(const CycleT & DefCycle)834 void GenericUniformityAnalysisImpl<ContextT>::analyzeCycleExitDivergence(
835     const CycleT &DefCycle) {
836   SmallVector<BlockT *> Exits;
837   DefCycle.getExitBlocks(Exits);
838   for (auto *Exit : Exits) {
839     for (auto &Phi : Exit->phis()) {
840       if (usesValueFromCycle(Phi, DefCycle)) {
841         markDivergent(Phi);
842       }
843     }
844   }
845 
846   for (auto *BB : DefCycle.blocks()) {
847     if (!llvm::any_of(Exits,
848                      [&](BlockT *Exit) { return DT.dominates(BB, Exit); }))
849       continue;
850     for (auto &II : *BB) {
851       propagateTemporalDivergence(II, DefCycle);
852     }
853   }
854 }
855 
856 template <typename ContextT>
propagateCycleExitDivergence(const BlockT & DivExit,const CycleT & InnerDivCycle)857 void GenericUniformityAnalysisImpl<ContextT>::propagateCycleExitDivergence(
858     const BlockT &DivExit, const CycleT &InnerDivCycle) {
859   LLVM_DEBUG(dbgs() << "\tpropCycleExitDiv " << Context.print(&DivExit)
860                     << "\n");
861   auto *DivCycle = &InnerDivCycle;
862   auto *OuterDivCycle = DivCycle;
863   auto *ExitLevelCycle = CI.getCycle(&DivExit);
864   const unsigned CycleExitDepth =
865       ExitLevelCycle ? ExitLevelCycle->getDepth() : 0;
866 
867   // Find outer-most cycle that does not contain \p DivExit
868   while (DivCycle && DivCycle->getDepth() > CycleExitDepth) {
869     LLVM_DEBUG(dbgs() << "  Found exiting cycle: "
870                       << Context.print(DivCycle->getHeader()) << "\n");
871     OuterDivCycle = DivCycle;
872     DivCycle = DivCycle->getParentCycle();
873   }
874   LLVM_DEBUG(dbgs() << "\tOuter-most exiting cycle: "
875                     << Context.print(OuterDivCycle->getHeader()) << "\n");
876 
877   if (!DivergentExitCycles.insert(OuterDivCycle).second)
878     return;
879 
880   // Exit divergence does not matter if the cycle itself is assumed to
881   // be divergent.
882   for (const auto *C : AssumedDivergent) {
883     if (C->contains(OuterDivCycle))
884       return;
885   }
886 
887   analyzeCycleExitDivergence(*OuterDivCycle);
888 }
889 
890 template <typename ContextT>
taintAndPushAllDefs(const BlockT & BB)891 void GenericUniformityAnalysisImpl<ContextT>::taintAndPushAllDefs(
892     const BlockT &BB) {
893   LLVM_DEBUG(dbgs() << "taintAndPushAllDefs " << Context.print(&BB) << "\n");
894   for (const auto &I : instrs(BB)) {
895     // Terminators do not produce values; they are divergent only if
896     // the condition is divergent. That is handled when the divergent
897     // condition is placed in the worklist.
898     if (I.isTerminator())
899       break;
900 
901     markDivergent(I);
902   }
903 }
904 
905 /// Mark divergent phi nodes in a join block
906 template <typename ContextT>
taintAndPushPhiNodes(const BlockT & JoinBlock)907 void GenericUniformityAnalysisImpl<ContextT>::taintAndPushPhiNodes(
908     const BlockT &JoinBlock) {
909   LLVM_DEBUG(dbgs() << "taintAndPushPhiNodes in " << Context.print(&JoinBlock)
910                     << "\n");
911   for (const auto &Phi : JoinBlock.phis()) {
912     // FIXME: The non-undef value is not constant per se; it just happens to be
913     // uniform and may not dominate this PHI. So assuming that the same value
914     // reaches along all incoming edges may itself be undefined behaviour. This
915     // particular interpretation of the undef value was added to
916     // DivergenceAnalysis in the following review:
917     //
918     // https://reviews.llvm.org/D19013
919     if (ContextT::isConstantOrUndefValuePhi(Phi))
920       continue;
921     markDivergent(Phi);
922   }
923 }
924 
925 /// Add \p Candidate to \p Cycles if it is not already contained in \p Cycles.
926 ///
927 /// \return true iff \p Candidate was added to \p Cycles.
928 template <typename CycleT>
insertIfNotContained(SmallVector<CycleT * > & Cycles,CycleT * Candidate)929 static bool insertIfNotContained(SmallVector<CycleT *> &Cycles,
930                                  CycleT *Candidate) {
931   if (llvm::any_of(Cycles,
932                    [Candidate](CycleT *C) { return C->contains(Candidate); }))
933     return false;
934   Cycles.push_back(Candidate);
935   return true;
936 }
937 
938 /// Return the outermost cycle made divergent by branch outside it.
939 ///
940 /// If two paths that diverged outside an irreducible cycle join
941 /// inside that cycle, then that whole cycle is assumed to be
942 /// divergent. This does not apply if the cycle is reducible.
943 template <typename CycleT, typename BlockT>
getExtDivCycle(const CycleT * Cycle,const BlockT * DivTermBlock,const BlockT * JoinBlock)944 static const CycleT *getExtDivCycle(const CycleT *Cycle,
945                                     const BlockT *DivTermBlock,
946                                     const BlockT *JoinBlock) {
947   assert(Cycle);
948   assert(Cycle->contains(JoinBlock));
949 
950   if (Cycle->contains(DivTermBlock))
951     return nullptr;
952 
953   const auto *OriginalCycle = Cycle;
954   const auto *Parent = Cycle->getParentCycle();
955   while (Parent && !Parent->contains(DivTermBlock)) {
956     Cycle = Parent;
957     Parent = Cycle->getParentCycle();
958   }
959 
960   // If the original cycle is not the outermost cycle, then the outermost cycle
961   // is irreducible. If the outermost cycle were reducible, then external
962   // diverged paths would not reach the original inner cycle.
963   (void)OriginalCycle;
964   assert(Cycle == OriginalCycle || !Cycle->isReducible());
965 
966   if (Cycle->isReducible()) {
967     assert(Cycle->getHeader() == JoinBlock);
968     return nullptr;
969   }
970 
971   LLVM_DEBUG(dbgs() << "cycle made divergent by external branch\n");
972   return Cycle;
973 }
974 
975 /// Return the outermost cycle made divergent by branch inside it.
976 ///
977 /// This checks the "diverged entry" criterion defined in the
978 /// docs/ConvergenceAnalysis.html.
979 template <typename ContextT, typename CycleT, typename BlockT,
980           typename DominatorTreeT>
981 static const CycleT *
getIntDivCycle(const CycleT * Cycle,const BlockT * DivTermBlock,const BlockT * JoinBlock,const DominatorTreeT & DT,ContextT & Context)982 getIntDivCycle(const CycleT *Cycle, const BlockT *DivTermBlock,
983                const BlockT *JoinBlock, const DominatorTreeT &DT,
984                ContextT &Context) {
985   LLVM_DEBUG(dbgs() << "examine join " << Context.print(JoinBlock)
986                     << " for internal branch " << Context.print(DivTermBlock)
987                     << "\n");
988   if (DT.properlyDominates(DivTermBlock, JoinBlock))
989     return nullptr;
990 
991   // Find the smallest common cycle, if one exists.
992   assert(Cycle && Cycle->contains(JoinBlock));
993   while (Cycle && !Cycle->contains(DivTermBlock)) {
994     Cycle = Cycle->getParentCycle();
995   }
996   if (!Cycle || Cycle->isReducible())
997     return nullptr;
998 
999   if (DT.properlyDominates(Cycle->getHeader(), JoinBlock))
1000     return nullptr;
1001 
1002   LLVM_DEBUG(dbgs() << "  header " << Context.print(Cycle->getHeader())
1003                     << " does not dominate join\n");
1004 
1005   const auto *Parent = Cycle->getParentCycle();
1006   while (Parent && !DT.properlyDominates(Parent->getHeader(), JoinBlock)) {
1007     LLVM_DEBUG(dbgs() << "  header " << Context.print(Parent->getHeader())
1008                       << " does not dominate join\n");
1009     Cycle = Parent;
1010     Parent = Parent->getParentCycle();
1011   }
1012 
1013   LLVM_DEBUG(dbgs() << "  cycle made divergent by internal branch\n");
1014   return Cycle;
1015 }
1016 
/// Determine the outermost cycle made divergent by the branch in
/// \p DivTermBlock joining at \p JoinBlock, testing both the external
/// and the internal (diverged entry) criteria.
template <typename ContextT, typename CycleT, typename BlockT,
          typename DominatorTreeT>
static const CycleT *
getOutermostDivergentCycle(const CycleT *Cycle, const BlockT *DivTermBlock,
                           const BlockT *JoinBlock, const DominatorTreeT &DT,
                           ContextT &Context) {
  if (!Cycle)
    return nullptr;

  // First try to expand Cycle to the largest that contains JoinBlock
  // but not DivTermBlock.
  const auto *Ext = getExtDivCycle(Cycle, DivTermBlock, JoinBlock);

  // Continue expanding to the largest cycle that contains both.
  const auto *Int = getIntDivCycle(Cycle, DivTermBlock, JoinBlock, DT, Context);

  // Prefer the internal criterion when it applies.
  return Int ? Int : Ext;
}
1037 
1038 template <typename ContextT>
isTemporalDivergent(const BlockT & ObservingBlock,const InstructionT & Def)1039 bool GenericUniformityAnalysisImpl<ContextT>::isTemporalDivergent(
1040     const BlockT &ObservingBlock, const InstructionT &Def) const {
1041   const BlockT *DefBlock = Def.getParent();
1042   for (const CycleT *Cycle = CI.getCycle(DefBlock);
1043        Cycle && !Cycle->contains(&ObservingBlock);
1044        Cycle = Cycle->getParentCycle()) {
1045     if (DivergentExitCycles.contains(Cycle)) {
1046       return true;
1047     }
1048   }
1049   return false;
1050 }
1051 
1052 template <typename ContextT>
analyzeControlDivergence(const InstructionT & Term)1053 void GenericUniformityAnalysisImpl<ContextT>::analyzeControlDivergence(
1054     const InstructionT &Term) {
1055   const auto *DivTermBlock = Term.getParent();
1056   DivergentTermBlocks.insert(DivTermBlock);
1057   LLVM_DEBUG(dbgs() << "analyzeControlDiv " << Context.print(DivTermBlock)
1058                     << "\n");
1059 
1060   // Don't propagate divergence from unreachable blocks.
1061   if (!DT.isReachableFromEntry(DivTermBlock))
1062     return;
1063 
1064   const auto &DivDesc = SDA.getJoinBlocks(DivTermBlock);
1065   SmallVector<const CycleT *> DivCycles;
1066 
1067   // Iterate over all blocks now reachable by a disjoint path join
1068   for (const auto *JoinBlock : DivDesc.JoinDivBlocks) {
1069     const auto *Cycle = CI.getCycle(JoinBlock);
1070     LLVM_DEBUG(dbgs() << "visiting join block " << Context.print(JoinBlock)
1071                       << "\n");
1072     if (const auto *Outermost = getOutermostDivergentCycle(
1073             Cycle, DivTermBlock, JoinBlock, DT, Context)) {
1074       LLVM_DEBUG(dbgs() << "found divergent cycle\n");
1075       DivCycles.push_back(Outermost);
1076       continue;
1077     }
1078     taintAndPushPhiNodes(*JoinBlock);
1079   }
1080 
1081   // Sort by order of decreasing depth. This allows later cycles to be skipped
1082   // because they are already contained in earlier ones.
1083   llvm::sort(DivCycles, [](const CycleT *A, const CycleT *B) {
1084     return A->getDepth() > B->getDepth();
1085   });
1086 
1087   // Cycles that are assumed divergent due to the diverged entry
1088   // criterion potentially contain temporal divergence depending on
1089   // the DFS chosen. Conservatively, all values produced in such a
1090   // cycle are assumed divergent. "Cycle invariant" values may be
1091   // assumed uniform, but that requires further analysis.
1092   for (auto *C : DivCycles) {
1093     if (!insertIfNotContained(AssumedDivergent, C))
1094       continue;
1095     LLVM_DEBUG(dbgs() << "process divergent cycle\n");
1096     for (const BlockT *BB : C->blocks()) {
1097       taintAndPushAllDefs(*BB);
1098     }
1099   }
1100 
1101   const auto *BranchCycle = CI.getCycle(DivTermBlock);
1102   assert(DivDesc.CycleDivBlocks.empty() || BranchCycle);
1103   for (const auto *DivExitBlock : DivDesc.CycleDivBlocks) {
1104     propagateCycleExitDivergence(*DivExitBlock, *BranchCycle);
1105   }
1106 }
1107 
1108 template <typename ContextT>
compute()1109 void GenericUniformityAnalysisImpl<ContextT>::compute() {
1110   // Initialize worklist.
1111   auto DivValuesCopy = DivergentValues;
1112   for (const auto DivVal : DivValuesCopy) {
1113     assert(isDivergent(DivVal) && "Worklist invariant violated!");
1114     pushUsers(DivVal);
1115   }
1116 
1117   // All values on the Worklist are divergent.
1118   // Their users may not have been updated yet.
1119   while (!Worklist.empty()) {
1120     const InstructionT *I = Worklist.back();
1121     Worklist.pop_back();
1122 
1123     LLVM_DEBUG(dbgs() << "worklist pop: " << Context.print(I) << "\n");
1124 
1125     if (I->isTerminator()) {
1126       analyzeControlDivergence(*I);
1127       continue;
1128     }
1129 
1130     // propagate value divergence to users
1131     assert(isDivergent(*I) && "Worklist invariant violated!");
1132     pushUsers(*I);
1133   }
1134 }
1135 
1136 template <typename ContextT>
isAlwaysUniform(const InstructionT & Instr)1137 bool GenericUniformityAnalysisImpl<ContextT>::isAlwaysUniform(
1138     const InstructionT &Instr) const {
1139   return UniformOverrides.contains(&Instr);
1140 }
1141 
1142 template <typename ContextT>
GenericUniformityInfo(const DominatorTreeT & DT,const CycleInfoT & CI,const TargetTransformInfo * TTI)1143 GenericUniformityInfo<ContextT>::GenericUniformityInfo(
1144     const DominatorTreeT &DT, const CycleInfoT &CI,
1145     const TargetTransformInfo *TTI) {
1146   DA.reset(new ImplT{DT, CI, TTI});
1147 }
1148 
1149 template <typename ContextT>
print(raw_ostream & OS)1150 void GenericUniformityAnalysisImpl<ContextT>::print(raw_ostream &OS) const {
1151   bool haveDivergentArgs = false;
1152 
1153   // Control flow instructions may be divergent even if their inputs are
1154   // uniform. Thus, although exceedingly rare, it is possible to have a program
1155   // with no divergent values but with divergent control structures.
1156   if (DivergentValues.empty() && DivergentTermBlocks.empty() &&
1157       DivergentExitCycles.empty()) {
1158     OS << "ALL VALUES UNIFORM\n";
1159     return;
1160   }
1161 
1162   for (const auto &entry : DivergentValues) {
1163     const BlockT *parent = Context.getDefBlock(entry);
1164     if (!parent) {
1165       if (!haveDivergentArgs) {
1166         OS << "DIVERGENT ARGUMENTS:\n";
1167         haveDivergentArgs = true;
1168       }
1169       OS << "  DIVERGENT: " << Context.print(entry) << '\n';
1170     }
1171   }
1172 
1173   if (!AssumedDivergent.empty()) {
1174     OS << "CYCLES ASSSUMED DIVERGENT:\n";
1175     for (const CycleT *cycle : AssumedDivergent) {
1176       OS << "  " << cycle->print(Context) << '\n';
1177     }
1178   }
1179 
1180   if (!DivergentExitCycles.empty()) {
1181     OS << "CYCLES WITH DIVERGENT EXIT:\n";
1182     for (const CycleT *cycle : DivergentExitCycles) {
1183       OS << "  " << cycle->print(Context) << '\n';
1184     }
1185   }
1186 
1187   for (auto &block : F) {
1188     OS << "\nBLOCK " << Context.print(&block) << '\n';
1189 
1190     OS << "DEFINITIONS\n";
1191     SmallVector<ConstValueRefT, 16> defs;
1192     Context.appendBlockDefs(defs, block);
1193     for (auto value : defs) {
1194       if (isDivergent(value))
1195         OS << "  DIVERGENT: ";
1196       else
1197         OS << "             ";
1198       OS << Context.print(value) << '\n';
1199     }
1200 
1201     OS << "TERMINATORS\n";
1202     SmallVector<const InstructionT *, 8> terms;
1203     Context.appendBlockTerms(terms, block);
1204     bool divergentTerminators = hasDivergentTerminator(block);
1205     for (auto *T : terms) {
1206       if (divergentTerminators)
1207         OS << "  DIVERGENT: ";
1208       else
1209         OS << "             ";
1210       OS << Context.print(T) << '\n';
1211     }
1212 
1213     OS << "END BLOCK\n";
1214   }
1215 }
1216 
1217 template <typename ContextT>
hasDivergence()1218 bool GenericUniformityInfo<ContextT>::hasDivergence() const {
1219   return DA->hasDivergence();
1220 }
1221 
1222 template <typename ContextT>
1223 const typename ContextT::FunctionT &
getFunction()1224 GenericUniformityInfo<ContextT>::getFunction() const {
1225   return DA->getFunction();
1226 }
1227 
1228 /// Whether \p V is divergent at its definition.
1229 template <typename ContextT>
isDivergent(ConstValueRefT V)1230 bool GenericUniformityInfo<ContextT>::isDivergent(ConstValueRefT V) const {
1231   return DA->isDivergent(V);
1232 }
1233 
1234 template <typename ContextT>
isDivergent(const InstructionT * I)1235 bool GenericUniformityInfo<ContextT>::isDivergent(const InstructionT *I) const {
1236   return DA->isDivergent(*I);
1237 }
1238 
1239 template <typename ContextT>
isDivergentUse(const UseT & U)1240 bool GenericUniformityInfo<ContextT>::isDivergentUse(const UseT &U) const {
1241   return DA->isDivergentUse(U);
1242 }
1243 
1244 template <typename ContextT>
hasDivergentTerminator(const BlockT & B)1245 bool GenericUniformityInfo<ContextT>::hasDivergentTerminator(const BlockT &B) {
1246   return DA->hasDivergentTerminator(B);
1247 }
1248 
1249 /// \brief T helper function for printing.
1250 template <typename ContextT>
print(raw_ostream & out)1251 void GenericUniformityInfo<ContextT>::print(raw_ostream &out) const {
1252   DA->print(out);
1253 }
1254 
1255 template <typename ContextT>
computeStackPO(SmallVectorImpl<const BlockT * > & Stack,const CycleInfoT & CI,const CycleT * Cycle,SmallPtrSetImpl<const BlockT * > & Finalized)1256 void llvm::ModifiedPostOrder<ContextT>::computeStackPO(
1257     SmallVectorImpl<const BlockT *> &Stack, const CycleInfoT &CI,
1258     const CycleT *Cycle, SmallPtrSetImpl<const BlockT *> &Finalized) {
1259   LLVM_DEBUG(dbgs() << "inside computeStackPO\n");
1260   while (!Stack.empty()) {
1261     auto *NextBB = Stack.back();
1262     if (Finalized.count(NextBB)) {
1263       Stack.pop_back();
1264       continue;
1265     }
1266     LLVM_DEBUG(dbgs() << "  visiting " << CI.getSSAContext().print(NextBB)
1267                       << "\n");
1268     auto *NestedCycle = CI.getCycle(NextBB);
1269     if (Cycle != NestedCycle && (!Cycle || Cycle->contains(NestedCycle))) {
1270       LLVM_DEBUG(dbgs() << "  found a cycle\n");
1271       while (NestedCycle->getParentCycle() != Cycle)
1272         NestedCycle = NestedCycle->getParentCycle();
1273 
1274       SmallVector<BlockT *, 3> NestedExits;
1275       NestedCycle->getExitBlocks(NestedExits);
1276       bool PushedNodes = false;
1277       for (auto *NestedExitBB : NestedExits) {
1278         LLVM_DEBUG(dbgs() << "  examine exit: "
1279                           << CI.getSSAContext().print(NestedExitBB) << "\n");
1280         if (Cycle && !Cycle->contains(NestedExitBB))
1281           continue;
1282         if (Finalized.count(NestedExitBB))
1283           continue;
1284         PushedNodes = true;
1285         Stack.push_back(NestedExitBB);
1286         LLVM_DEBUG(dbgs() << "  pushed exit: "
1287                           << CI.getSSAContext().print(NestedExitBB) << "\n");
1288       }
1289       if (!PushedNodes) {
1290         // All loop exits finalized -> finish this node
1291         Stack.pop_back();
1292         computeCyclePO(CI, NestedCycle, Finalized);
1293       }
1294       continue;
1295     }
1296 
1297     LLVM_DEBUG(dbgs() << "  no nested cycle, going into DAG\n");
1298     // DAG-style
1299     bool PushedNodes = false;
1300     for (auto *SuccBB : successors(NextBB)) {
1301       LLVM_DEBUG(dbgs() << "  examine succ: "
1302                         << CI.getSSAContext().print(SuccBB) << "\n");
1303       if (Cycle && !Cycle->contains(SuccBB))
1304         continue;
1305       if (Finalized.count(SuccBB))
1306         continue;
1307       PushedNodes = true;
1308       Stack.push_back(SuccBB);
1309       LLVM_DEBUG(dbgs() << "  pushed succ: " << CI.getSSAContext().print(SuccBB)
1310                         << "\n");
1311     }
1312     if (!PushedNodes) {
1313       // Never push nodes twice
1314       LLVM_DEBUG(dbgs() << "  finishing node: "
1315                         << CI.getSSAContext().print(NextBB) << "\n");
1316       Stack.pop_back();
1317       Finalized.insert(NextBB);
1318       appendBlock(*NextBB);
1319     }
1320   }
1321   LLVM_DEBUG(dbgs() << "exited computeStackPO\n");
1322 }
1323 
1324 template <typename ContextT>
computeCyclePO(const CycleInfoT & CI,const CycleT * Cycle,SmallPtrSetImpl<const BlockT * > & Finalized)1325 void ModifiedPostOrder<ContextT>::computeCyclePO(
1326     const CycleInfoT &CI, const CycleT *Cycle,
1327     SmallPtrSetImpl<const BlockT *> &Finalized) {
1328   LLVM_DEBUG(dbgs() << "inside computeCyclePO\n");
1329   SmallVector<const BlockT *> Stack;
1330   auto *CycleHeader = Cycle->getHeader();
1331 
1332   LLVM_DEBUG(dbgs() << "  noted header: "
1333                     << CI.getSSAContext().print(CycleHeader) << "\n");
1334   assert(!Finalized.count(CycleHeader));
1335   Finalized.insert(CycleHeader);
1336 
1337   // Visit the header last
1338   LLVM_DEBUG(dbgs() << "  finishing header: "
1339                     << CI.getSSAContext().print(CycleHeader) << "\n");
1340   appendBlock(*CycleHeader, Cycle->isReducible());
1341 
1342   // Initialize with immediate successors
1343   for (auto *BB : successors(CycleHeader)) {
1344     LLVM_DEBUG(dbgs() << "  examine succ: " << CI.getSSAContext().print(BB)
1345                       << "\n");
1346     if (!Cycle->contains(BB))
1347       continue;
1348     if (BB == CycleHeader)
1349       continue;
1350     if (!Finalized.count(BB)) {
1351       LLVM_DEBUG(dbgs() << "  pushed succ: " << CI.getSSAContext().print(BB)
1352                         << "\n");
1353       Stack.push_back(BB);
1354     }
1355   }
1356 
1357   // Compute PO inside region
1358   computeStackPO(Stack, CI, Cycle, Finalized);
1359 
1360   LLVM_DEBUG(dbgs() << "exited computeCyclePO\n");
1361 }
1362 
1363 /// \brief Generically compute the modified post order.
1364 template <typename ContextT>
compute(const CycleInfoT & CI)1365 void llvm::ModifiedPostOrder<ContextT>::compute(const CycleInfoT &CI) {
1366   SmallPtrSet<const BlockT *, 32> Finalized;
1367   SmallVector<const BlockT *> Stack;
1368   auto *F = CI.getFunction();
1369   Stack.reserve(24); // FIXME made-up number
1370   Stack.push_back(&F->front());
1371   computeStackPO(Stack, CI, nullptr, Finalized);
1372 }
1373 
1374 } // namespace llvm
1375 
1376 #undef DEBUG_TYPE
1377 
1378 #endif // LLVM_ADT_GENERICUNIFORMITYIMPL_H
1379