1 /*========================== begin_copyright_notice ============================
2 
3 Copyright (C) 2017-2021 Intel Corporation
4 
5 SPDX-License-Identifier: MIT
6 
7 ============================= end_copyright_notice ===========================*/
8 
9 //
10 /// GenXSimdCFConformance
11 /// ---------------------
12 ///
13 /// This pass checks that the use of SIMD control flow (llvm.genx.simdcf.goto
14 /// and llvm.genx.simdcf.join) conforms to the rules required to allow us to
15 /// generate actual goto and join instructions. If not, the intrinsics are
16 /// lowered to code that implements the defined semantics for the intrinsics,
17 /// but does not use SIMD CF instructions, so is usually less efficient.
18 ///
19 /// It also makes certain transformations to make goto/join legal in terms of its
20 /// position in the basic block. These can fail silently, in which case the
21 /// conformance check will fail on the goto/join in question:
22 ///
23 /// * A goto and its extractvalues must be at the end of the block. (Actually, if
24 ///   the !any result of the goto is used in a conditional branch at the end of
25 ///   the block, then the goto being baled into the branch means that it is
26 ///   treated as being at the end of the block anyway. The only reason we need
27 ///   to sink it here is to ensure that isGotoBlock works.)
28 ///
29 /// * For a join label block (a block that is the JIP of other gotos/joins), a
30 ///   join must come at the start of the block.
31 ///
32 /// * For a branching join block (one whose conditional branch condition is the
33 ///   !any result from a join), the join must be at the end of the block.
34 ///
35 /// * For a block that has one join with both of the above true, we need to move
36 ///   all other code out of the block.
37 ///
38 /// The pass is run twice: an "early SIMD CF conformance pass" (a module pass)
39 /// just before GenXLowering, and a "late SIMD CF conformance pass" (a function
40 /// group pass) just before second baling.
41 ///
42 /// The early pass is the one that checks for conformance, and lowers the goto
43 /// and join intrinsics if the code is not conformant. The conformance checks
44 /// implement the rules listed in the documentation for the goto and join
45 /// intrinsics.
46 ///
47 /// Lowering a goto issues a "failed to optimize SIMD control flow" warning. No
48 /// clue is given in the warning as to what caused the conformance failure,
49 /// however you (a compiler developer) can find out (for a test case submitted
50 /// by a compiler user) by turning on -debug and looking at the output from this
51 /// pass.
52 ///
53 /// The late pass checks again for conformance, but if the code is not
54 /// conformant, it just errors. We could lower the gotos and joins there too,
55 /// but it would be more fiddly as we would have to ensure that the code
56 /// conforms with what is expected at that stage of compilation, and there is
57 /// no further chance to optimize it there.
58 ///
59 /// We are not expecting this error to happen.
60 ///
61 /// Otherwise, the late pass sets the register category of the EM and RM values
62 /// to "EM" and "RM", so they do not get any register allocated.
63 ///
64 /// Conformance rules
65 /// ^^^^^^^^^^^^^^^^^
66 ///
67 /// If the goto and join intrinsics are not used in a way that conforms to the
68 /// rules, then they will still have the semantics in their spec, but this pass
69 /// will lower at least some of them to equivalent but less efficient code.
70 ///
71 /// The rules are:
72 ///
73 /// 1. Because the hardware has a single EM (execution mask) register, all EM
74 ///    values input to and generated by these intrinsics must not interfere with
75 ///    each other; that is, they must have disjoint live ranges. For the
76 ///    purposes of determining interference, if any EM value is a phi node
77 ///    with incoming constant all ones, then the constant all ones value is
78 ///    counted as being live from the start of the function and is not allowed
79 ///    to interfere with other EM values (although it can interfere with other
80 ///    such constant all ones values).
81 ///
82 /// 2. An EM value is allowed to be defined:
83 ///
84 ///    a. as part of the struct returned by one of these intrinsics;
85 ///
86 ///    b. by a phi node, as long as each incoming is either an EM value or
87 ///       a constant all ones;
88 ///
89 ///    c. by an extractvalue extracting it from a struct containing an EM value;
90 ///
91 ///    d. as a function argument, as long as an EM value is also returned by the
92 ///       function (perhaps as part of a struct);
93 ///
94 ///    e. by an insertvalue as part of a return value struct;
95 ///
96 ///    f. as the return value of a non-intrinsic call (perhaps as part of a struct),
97 ///       as long as there is also a call arg that is an EM value, and the called
98 ///       function has the corresponding function arg and return value as EM values;
99 ///
100 ///    g. since shufflevector from EM does not change EM and only makes it shorter
101 ///       to create implicit predication of desired width, it's also considered
102 ///       as an EM definition, but it can only be used by wrregion and select;
103 ///
104 /// 3. An EM value is allowed to be used:
105 ///
106 ///    a. as the OldEM input to one of these intrinsics;
107 ///
108 ///    b. in a phi node, as long as the result of the phi node is an EM value;
109 ///
110 ///    c. as the condition in a wrregion or select;
111 ///
112 ///    d. as the input to a shufflevector whose effect is to slice part of the EM
113 ///       value starting at index 0, as long as the result of that slice is only
114 ///       used as the condition in a wrregion or select;
115 ///
116 ///    e. as a call argument, as long as the corresponding function argument is an
117 ///       EM value, and the call has an EM return value;
118 ///
119 ///    f. in a return (perhaps as part of a struct), as long as the function also
120 ///       has an argument that is an EM value.
121 ///
122 ///    For an EM value defined in a goto, or a join whose scalar BranchCond result
123 ///    is used in a conditional branch, or in an extractvalue out of
124 ///    the result of such a goto or join, the only use allowed in the same basic block
125 ///    as the goto/join is such an extractvalue.
126 ///
127 /// 4. The OldEM input to the two intrinsics must be either an EM value or
128 ///    constant all ones. In the latter case, and in the case of a constant incoming
129 ///    to an EM phi node, its live range is considered to reach
130 ///    back through all paths to the function entry for the purposes of rule (1).
131 ///
132 /// 5. Each join point has a web of RM (resume mask) values, linked as by rules (6)
133 ///    and (7). All RM values within one join point's web must not interfere with
134 ///    each other; that is, they must have disjoint live ranges. For the
135 ///    purposes of determining interference, if an RM value is a phi node with
136 ///    incoming constant all zeros, then the constant all zeros value is
137 ///    counted as being live from the start of the function and is not allowed
138 ///    to interfere with other RM values for this join (although it can
139 ///    interfere with other such constant all zeros values).
140 ///
141 /// 6. An RM value is allowed to be defined:
142 ///
143 ///    a. as part of the struct returned by ``llvm.genx.simdcf.goto``;
144 ///
145 ///    b. by a phi node, as long as each incoming is either an RM value or
146 ///       a constant all zeros.
147 ///
148 /// 7. An RM value is allowed to be used:
149 ///
150 ///    a. as the OldRM input to ``llvm.genx.simdcf.goto``;
151 ///
152 ///    b. as the RM input to ``llvm.genx.simdcf.join``, but only to one join in the
153 ///       whole web;
154 ///
155 ///    c. in a phi node, as long as the result of the phi node is an RM value.
156 ///
157 /// 8. The OldRM input to ``llvm.genx.simdcf.goto``, or the RM input to
158 ///    ``llvm.genx.simdcf.join``, must be either an RM value, or constant all
159 ///    zeros. In the latter case, and in the case of a constant incoming to an RM
160 ///    phi node, its live range is considered to reach back through all paths
161 ///    to the function entry or to the web's ``llvm.genx.simdcf.join`` for the
162 ///    purposes of rule (5).
163 ///
164 /// 9. The BranchCond struct element of the result of ``llvm.genx.simdcf.goto``
165 ///    must either be unused (unextracted), or, after being extractvalued,
166 ///    must have exactly one use, which is in a
167 ///    conditional branch terminating the same basic block. In the unused case,
168 ///    the basic block must end with an unconditional branch. (This is a goto
169 ///    that is immediately followed by a join.)
170 ///
171 /// 10. The BranchCond struct element of the result of ``llvm.genx.simdcf.join``
172 ///     must either be unused (unextracted), or, after being extractvalued,
173 ///     have exactly one use, which is in a conditional branch terminating the
174 ///     same basic block.
175 ///
176 /// 11. It must be possible to derive an ordering for the basic blocks in a
177 ///     function such that, in the conditional branch using the result of any goto
178 ///     or join, the "false" successor is fall-through and the "true" successor is
179 ///     to a join later on in the sequence. For a goto followed by an
180 ///     unconditional branch, the successor is fall-through _and_ the next join
181 ///     in sequence.
182 ///
183 /// **IR restriction**: goto and join intrinsics must conform to these rules
184 /// (since this pass lowers any that do not).
185 ///
186 //===----------------------------------------------------------------------===//
187 #define DEBUG_TYPE "GENX_SIMDCFCONFORMANCE"
188 
189 #include "FunctionGroup.h"
190 #include "GenX.h"
191 #include "GenXConstants.h"
192 #include "GenXGotoJoin.h"
193 #include "GenXLiveness.h"
194 #include "GenXModule.h"
195 #include "GenXTargetMachine.h"
196 #include "GenXUtil.h"
197 #include "vc/GenXOpts/Utils/KernelInfo.h"
198 #include "vc/GenXOpts/Utils/RegCategory.h"
199 
200 #include "llvm/ADT/MapVector.h"
201 #include "llvm/ADT/PostOrderIterator.h"
202 #include "llvm/ADT/SetVector.h"
203 #include "llvm/ADT/SmallSet.h"
204 #include "llvm/Analysis/CFG.h"
205 #include "llvm/CodeGen/TargetPassConfig.h"
206 #include "llvm/GenXIntrinsics/GenXIntrinsics.h"
207 #include "llvm/IR/Constants.h"
208 #include "llvm/IR/DiagnosticInfo.h"
209 #include "llvm/IR/DiagnosticPrinter.h"
210 #include "llvm/IR/Dominators.h"
211 #include "llvm/IR/IRBuilder.h"
212 #include "llvm/IR/Instructions.h"
213 #include "llvm/IR/LLVMContext.h"
214 #include "llvm/IR/PatternMatch.h"
215 #include "llvm/Support/CommandLine.h"
216 #include "llvm/Support/Debug.h"
217 #include "llvm/Transforms/Utils/Local.h"
218 
219 #include "Probe/Assertion.h"
220 #include "llvmWrapper/IR/DerivedTypes.h"
221 #include "llvmWrapper/IR/InstrTypes.h"
222 #include "llvmWrapper/IR/IntrinsicInst.h"
223 #include "llvmWrapper/Support/TypeSize.h"
224 
225 using namespace llvm;
226 using namespace genx;
227 
228 static cl::opt<bool> EnableGenXGotoJoin("enable-genx-goto-join", cl::init(true), cl::Hidden,
229                                         cl::desc("Enable use of Gen goto/join instructions for SIMD control flow."));
230 
231 namespace {
232 
233 // Diagnostic information for error/warning relating to SIMD control flow.
234 class DiagnosticInfoSimdCF : public DiagnosticInfoOptimizationBase {
235 private:
236   static int KindID;
getKindID()237   static int getKindID() {
238     if (KindID == 0)
239       KindID = llvm::getNextAvailablePluginDiagnosticKind();
240     return KindID;
241   }
242 public:
243   static void emit(Instruction *Inst, StringRef Msg, DiagnosticSeverity Severity = DS_Error);
DiagnosticInfoSimdCF(DiagnosticSeverity Severity,const Function & Fn,const DebugLoc & DLoc,StringRef Msg)244   DiagnosticInfoSimdCF(DiagnosticSeverity Severity, const Function &Fn,
245       const DebugLoc &DLoc, StringRef Msg)
246       : DiagnosticInfoOptimizationBase((DiagnosticKind)getKindID(), Severity,
247           /*PassName=*/nullptr, Msg, Fn, DLoc) {}
248   // This kind of message is always enabled, and not affected by -rpass.
isEnabled() const249   bool isEnabled() const override { return true; }
classof(const DiagnosticInfo * DI)250   static bool classof(const DiagnosticInfo *DI) {
251     return DI->getKind() == getKindID();
252   }
253 
254   // TODO: consider changing format
print(DiagnosticPrinter & DP) const255   void print(DiagnosticPrinter &DP) const override { DP << "GenXSimdCFConformance: " << RemarkName; }
256 };
257 int DiagnosticInfoSimdCF::KindID = 0;
258 
259 // GenX SIMD control flow conformance pass -- common data between early and
260 // late passes.
261 class GenXSimdCFConformance {
262 protected:
263   Module *M = nullptr;
264   FunctionGroup *FG = nullptr;
265   FunctionGroupAnalysis *FGA = nullptr;
266   DominatorTreeGroupWrapperPass *DTWrapper = nullptr;
267   std::map<Function *, DominatorTree *> DTs;
268   GenXLiveness *Liveness = nullptr;
269   bool Modified = false;
270   SetVector<SimpleValue> EMVals;
271   std::map<CallInst *, SetVector<SimpleValue>> RMVals;
272   bool lowerSimdCF = false;
273 private:
274 
275   // GotoJoinEVs: container for goto/join Extract Value (EV) info. Also
276   // allowes to remove duplication of EVs. Performs it in construction
277   // and moves EVs right after goto/join. Hoisting can be performed
278   // again with hoistEVs method. For instance, it is used on join
279   // hoisting to save correct EM liveranges.
280   class GotoJoinEVs {
281   private:
282     enum ValPos {
283       EMPos = 0,
284       RMPos = 1,
285       JoinCondPos = 1,
286       GotoCondPos = 2,
287       PosNum
288     };
289 
290     bool testPosCorrectness(const unsigned Index) const;
291 
292     ExtractValueInst *EVs[PosNum] = { nullptr, nullptr, nullptr };
293     bool IsGoto;
294     Value *GotoJoin;
295 
296     void CollectEVs();
297 
298   public:
299 
300     GotoJoinEVs(Value *GJ = nullptr);
301     ExtractValueInst *getEMEV() const;
302     ExtractValueInst *getRMEV() const;
303     ExtractValueInst *getCondEV() const;
304     Value *getGotoJoin() const;
305     Instruction *getSplitPoint() const;
306     void setCondEV(ExtractValueInst *CondEV);
307     bool isGoto() const;
308     bool isJoin() const;
309     void hoistEVs() const;
310 
311   };
312 
313   class JoinPointOptData {
314   private:
315     BasicBlock *FalsePred;
316     Instruction *EM;
317 
318   public:
JoinPointOptData(BasicBlock * FalsePred=nullptr,Instruction * EM=nullptr)319     JoinPointOptData(BasicBlock *FalsePred = nullptr, Instruction *EM = nullptr)
320         : FalsePred(FalsePred), EM(EM) {}
getTruePred() const321     BasicBlock *getTruePred() const { return EM->getParent(); }
getFalsePred() const322     BasicBlock *getFalsePred() const { return FalsePred; }
getRealEM() const323     Instruction *getRealEM() const { return EM; }
324   };
325 
326   SetVector<SimpleValue> EMValsStack;
327   MapVector<Value *, GotoJoinEVs> GotoJoinEVsMap;
328   MapVector<BasicBlock *, JoinPointOptData> BlocksToOptimize;
329   std::map<CallInst *, CallInst *> GotoJoinMap;
330   std::map<Value *, Value *> EMProducers;
331   std::map<Value *, Value *> LoweredEMValsMap;
332 
333 protected:
GenXSimdCFConformance()334   GenXSimdCFConformance() :
335     M(0), FG(0), FGA(0), DTWrapper(0), Liveness(0), lowerSimdCF(false) {}
336   void gatherEMVals();
337   void gatherRMVals();
338   void removeFromEMRMVals(Value *V);
339   void moveCodeInGotoBlocks(bool hoistGotoUsers = false);
340   void moveCodeInJoinBlocks();
341   void ensureConformance();
342   void lowerAllSimdCF();
343   void canonicalizeEM();
344   void splitGotoJoinBlocks();
345   void lowerUnsuitableGetEMs();
346   void optimizeRestoredSIMDCF();
clear()347   void clear() {
348     DTs.clear();
349     EMVals.clear();
350     RMVals.clear();
351     GotoJoinMap.clear();
352     GotoJoinEVsMap.clear();
353     EMProducers.clear();
354     LoweredEMValsMap.clear();
355     BlocksToOptimize.clear();
356   }
357   DominatorTree *getDomTree(Function *F);
358 private:
isLatePass() const359   bool isLatePass() const { return FG != nullptr; }
360   void emptyBranchingJoinBlocksInFunc(Function *F);
361   void emptyBranchingJoinBlock(CallInst *Join);
362   bool hoistJoin(CallInst *Join);
363   bool checkEMVal(SimpleValue EMVal);
364   bool checkGoto(SimpleValue EMVal);
365   bool checkJoin(SimpleValue EMVal);
366   bool checkGotoJoin(SimpleValue EMVal);
367   void removeBadEMVal(SimpleValue EMVal);
368   void pushValues(Value *V);
369   bool getConnectedVals(SimpleValue Val, int Cat, bool IncludeOptional, CallInst *OkJoin, SmallVectorImpl<SimpleValue> *ConnectedVals, bool LowerBadUsers = false);
370   void checkEMInterference();
371   void checkInterference(SetVector<SimpleValue> *Vals, SetVector<Value *> *BadDefs, Instruction *ConstStop);
372   bool hoistGotoUser(Instruction *Inst, CallInst *Goto, unsigned operandNo);
373   void gatherGotoJoinEMVals(bool IncludeIncoming = true);
374   void handleEVs();
375   void resolveBitCastChains();
376   Value *eliminateBitCastPreds(Value *Val, std::set<Value *> &DeadInst, std::set<Value *> &Visited);
377   Value *getEMProducer(Value *Inst, std::set<Value *> &Visited, bool BitCastAllowed = false);
378   void handleCondValue(Value *GotoJoin);
379   void handleNoCondEVCase(GotoJoinEVs &GotoJoinData);
380   void handleOptimizedBranchCase(GotoJoinEVs &GotoJoinData, BasicBlock *&TrueSucc, BasicBlock *&FalseSucc);
381   void handleExistingBranchCase(GotoJoinEVs &GotoJoinData, BasicBlock *&TrueSucc, BasicBlock *&FalseSucc, BranchInst *ExistingBranch);
382   void addNewPhisIncomings(BasicBlock *BranchingBlock, BasicBlock *TrueSucc, BasicBlock *FalseSucc);
383   void collectCondEVUsers(ExtractValueInst *CondEV, std::vector<Value *> &BadUsers, BranchInst *&CorrectUser);
384   void updateBadCondEVUsers(GotoJoinEVs &GotoJoinData, std::vector<Value *> &BadUsers, BasicBlock *TrueSucc, BasicBlock *FalseSucc);
385   Value *findGotoJoinVal(int Cat, BasicBlock *Loc, Instruction *CondEV, BasicBlockEdge &TrueEdge, BasicBlockEdge &FalseEdge, Value *TrueVal,
386     Value *FalseVal, std::map<BasicBlock *, Value *> &foundVals);
387   bool canUseLoweredEM(Instruction *Val);
388   bool canUseRealEM(Instruction *Inst, unsigned opNo);
389   void replaceUseWithLoweredEM(Instruction *Val, unsigned opNo, SetVector<Value *> &ToRemove);
390   Value *findLoweredEMValue(Value *Val);
391   Value *buildLoweringViaGetEM(Value *Val, Instruction *InsertBefore);
392   Value *getGetEMLoweredValue(Value *Val, Instruction *InsertBefore);
393   Value *lowerEVIUse(ExtractValueInst *EVI, Instruction *User,
394                      BasicBlock *PhiPredBlock = nullptr);
395   Value *lowerPHIUse(PHINode *PN, SetVector<Value *> &ToRemove);
396   Value *lowerArgumentUse(Argument *Arg);
397   Value *insertCond(Value *OldVal, Value *NewVal, const Twine &Name, Instruction *InsertBefore, const DebugLoc &DL);
398   Value *truncateCond(Value *In, Type *Ty, const Twine &Name, Instruction *InsertBefore, const DebugLoc &DL);
399   void lowerGoto(CallInst *Goto);
400   void lowerJoin(CallInst *Join);
401   void replaceGotoJoinUses(CallInst *GotoJoin, ArrayRef<Value *> Vals);
402   void optimizeLinearization(BasicBlock *BB, JoinPointOptData &JPData);
403   bool isActualStoredEM(Instruction *Inst, JoinPointOptData &JPData);
404   bool canBeMovedUnderSIMDCF(Value *Val, BasicBlock *CurrBB,
405                              JoinPointOptData &JPData,
406                              std::set<Instruction *> &Visited);
407   bool isSelectConditionCondEV(SelectInst *Sel, JoinPointOptData &JPData);
408   void replaceGetEMUse(Instruction *Inst, JoinPointOptData &JPData);
409 };
410 
411 // GenX early SIMD control flow conformance pass
412 class GenXEarlySimdCFConformance
413     : public GenXSimdCFConformance, public ModulePass {
414 public:
415   static char ID;
GenXEarlySimdCFConformance()416   explicit GenXEarlySimdCFConformance() : ModulePass(ID) { }
getPassName() const417   StringRef getPassName() const override {
418     return "GenX early SIMD control flow conformance";
419   }
getAnalysisUsage(AnalysisUsage & AU) const420   void getAnalysisUsage(AnalysisUsage &AU) const override {
421     ModulePass::getAnalysisUsage(AU);
422   }
423   bool runOnModule(Module &M) override;
424 };
425 
426 // GenX late SIMD control flow conformance pass
427 class GenXLateSimdCFConformance : public FGPassImplInterface,
428                                   public IDMixin<GenXLateSimdCFConformance>,
429                                   public GenXSimdCFConformance {
430 public:
GenXLateSimdCFConformance()431   explicit GenXLateSimdCFConformance() {}
getPassName()432   static StringRef getPassName() {
433     return "GenX late SIMD control flow conformance";
434   }
getAnalysisUsage(AnalysisUsage & AU)435   static void getAnalysisUsage(AnalysisUsage &AU) {
436     AU.addRequired<DominatorTreeGroupWrapperPass>();
437     AU.addRequired<GenXLiveness>();
438     AU.addRequired<TargetPassConfig>();
439     AU.addPreserved<GenXModule>();
440     AU.addPreserved<GenXLiveness>();
441     AU.addPreserved<FunctionGroupAnalysis>();
442   }
releaseMemory()443   void releaseMemory() override { clear(); }
444   bool runOnFunctionGroup(FunctionGroup &FG) override;
445 
446 private:
447   void setCategories();
448   void modifyEMUses(Value *EM);
449 };
450 
451 /***********************************************************************
452  * Local function for testing one assertion statement.
453  * It returns true if intrinsic is GOTO or JOIN as expected.
454  */
testIsGotoJoin(const llvm::Value * const GotoJoin)455 bool testIsGotoJoin(const llvm::Value *const GotoJoin) {
456   bool Result = false;
457   IGC_ASSERT(GotoJoin);
458   const llvm::GenXIntrinsic::ID ID = llvm::GenXIntrinsic::getGenXIntrinsicID(GotoJoin);
459   switch(ID)
460   {
461     case llvm::GenXIntrinsic::genx_simdcf_goto:
462     case llvm::GenXIntrinsic::genx_simdcf_join:
463       Result = true;
464       break;
465     default:
466       IGC_ASSERT(0);
467       Result = false;
468       break;
469   }
470   return Result;
471 }
472 
473 /***********************************************************************
474  * Local function for testing one assertion statement.
475  * It returns true if intrinsic is JOIN as expected.
476  */
testIsJoin(const llvm::Value * const GotoJoin)477 bool testIsJoin(const llvm::Value *const GotoJoin) {
478   bool Result = false;
479   IGC_ASSERT(GotoJoin);
480   const llvm::GenXIntrinsic::ID ID = llvm::GenXIntrinsic::getGenXIntrinsicID(GotoJoin);
481   switch(ID)
482   {
483     case llvm::GenXIntrinsic::genx_simdcf_join:
484       Result = true;
485       break;
486     default:
487       IGC_ASSERT(0);
488       Result = false;
489       break;
490   }
491   return Result;
492 }
493 
494 /***********************************************************************
495  * Local function for testing one assertion statement.
496  * It returns true if all is ok.
497  */
testIsValidEMUse(const llvm::Value * const User,const llvm::Value::use_iterator & ui)498 bool testIsValidEMUse(const llvm::Value *const User,
499   const llvm::Value::use_iterator& ui) {
500   bool Result = false;
501   IGC_ASSERT(User);
502   const unsigned int ID = llvm::GenXIntrinsic::getAnyIntrinsicID(User);
503   switch(ID)
504   {
505     case llvm::GenXIntrinsic::genx_rdpredregion:
506     case llvm::GenXIntrinsic::genx_simdcf_goto:
507     case llvm::GenXIntrinsic::genx_simdcf_join:
508     case llvm::GenXIntrinsic::genx_simdcf_get_em:
509     case llvm::GenXIntrinsic::genx_wrpredpredregion:
510       Result = true;
511       break;
512     case llvm::GenXIntrinsic::genx_wrregioni:
513     case llvm::GenXIntrinsic::genx_wrregionf:
514       Result = (ui->getOperandNo() ==
515         llvm::GenXIntrinsic::GenXRegion::PredicateOperandNum);
516       IGC_ASSERT(Result);
517       break;
518     case llvm::GenXIntrinsic::not_any_intrinsic:
519       Result = (isa<PHINode>(User) ||
520         isa<InsertValueInst>(User) ||
521         isa<CallInst>(User) ||
522         isa<ReturnInst>(User) ||
523         isa<ShuffleVectorInst>(User));
524       IGC_ASSERT_MESSAGE(Result, "unexpected use of EM");
525       break;
526     default:
527       Result = (isa<ReturnInst>(User) ||
528         isa<InsertValueInst>(User) ||
529         isa<ExtractValueInst>(User) ||
530         !cast<CallInst>(User)->getCalledFunction()->doesNotAccessMemory());
531       IGC_ASSERT_MESSAGE(Result, "unexpected ALU intrinsic use of EM");
532       break;
533   }
534   return Result;
535 }
536 
537 /***********************************************************************
538  * Local function for testing one assertion statement.
539  * It returns true if Pos is correct.
540  */
testPosCorrectness(const unsigned Index) const541 bool GenXSimdCFConformance::GotoJoinEVs::testPosCorrectness(
542   const unsigned Index) const {
543   bool Result = false;
544   switch (Index)
545   {
546     case EMPos:
547     case RMPos: // same as JoinCondPos
548       Result = true;
549       break;
550     case GotoCondPos:
551       Result = IsGoto;
552       IGC_ASSERT_MESSAGE(Result, "Bad index in ExtractValue for goto/join!");
553       break;
554     default:
555       Result = false;
556       IGC_ASSERT_MESSAGE(0, "Bad index in ExtractValue for goto/join!");
557       break;
558   }
559   return Result;
560 }
561 
562 } // end anonymous namespace
563 
564 char GenXEarlySimdCFConformance::ID = 0;
565 namespace llvm { void initializeGenXEarlySimdCFConformancePass(PassRegistry &); }
566 INITIALIZE_PASS_BEGIN(GenXEarlySimdCFConformance, "GenXEarlySimdCFConformance", "GenXEarlySimdCFConformance", false, false)
567 INITIALIZE_PASS_END(GenXEarlySimdCFConformance, "GenXEarlySimdCFConformance", "GenXEarlySimdCFConformance", false, false)
568 
createGenXEarlySimdCFConformancePass()569 ModulePass *llvm::createGenXEarlySimdCFConformancePass()
570 {
571   initializeGenXEarlySimdCFConformancePass(*PassRegistry::getPassRegistry());
572   return new GenXEarlySimdCFConformance();
573 }
574 
575 namespace llvm {
576 void initializeGenXLateSimdCFConformanceWrapperPass(PassRegistry &);
577 using GenXLateSimdCFConformanceWrapper =
578     FunctionGroupWrapperPass<GenXLateSimdCFConformance>;
579 } // namespace llvm
580 INITIALIZE_PASS_BEGIN(GenXLateSimdCFConformanceWrapper,
581                       "GenXLateSimdCFConformanceWrapper",
582                       "GenXLateSimdCFConformanceWrapper", false, false)
INITIALIZE_PASS_DEPENDENCY(FunctionGroupAnalysis)583 INITIALIZE_PASS_DEPENDENCY(FunctionGroupAnalysis)
584 INITIALIZE_PASS_DEPENDENCY(DominatorTreeGroupWrapperPassWrapper)
585 INITIALIZE_PASS_DEPENDENCY(GenXLivenessWrapper)
586 INITIALIZE_PASS_DEPENDENCY(GenXModule)
587 INITIALIZE_PASS_END(GenXLateSimdCFConformanceWrapper,
588                     "GenXLateSimdCFConformanceWrapper",
589                     "GenXLateSimdCFConformanceWrapper", false, false)
590 
591 ModulePass *llvm::createGenXLateSimdCFConformanceWrapperPass() {
592   initializeGenXLateSimdCFConformanceWrapperPass(
593       *PassRegistry::getPassRegistry());
594   return new GenXLateSimdCFConformanceWrapper();
595 }
596 
hasStackCall(const Module & M)597 static bool hasStackCall(const Module &M) {
598   return std::any_of(M.begin(), M.end(),
599                      [](const auto &F) { return genx::requiresStackCall(&F); });
600 }
601 
602 /***********************************************************************
603  * runOnModule : run the early SIMD control flow conformance pass for this
604  *  module
605  */
runOnModule(Module & ArgM)606 bool GenXEarlySimdCFConformance::runOnModule(Module &ArgM)
607 {
608   LLVM_DEBUG(dbgs() << "Early SIMD CF Conformance starts\n");
609 
610   Modified = false;
611   M = &ArgM;
612   FG = nullptr;
613   FGA = nullptr;
614   DTWrapper = nullptr;
615   lowerSimdCF = hasStackCall(ArgM);
616   // Perform actions to create correct DF for EM
617   canonicalizeEM();
618   // Gather the EM values, both from goto/join and phi nodes.
619   gatherEMVals();
620   // Gather the RM values from gotos and phi nodes.
621   gatherRMVals();
622   // Hoist instructions that does not depend on Goto's result.
623   // It is needed to perform correct split.
624   moveCodeInGotoBlocks();
625   // Split Goto/Join blocks to recreate actual SIMD CF
626   splitGotoJoinBlocks();
627   // Handle instructions that depend on Goto's result
628   moveCodeInGotoBlocks(true);
629   // Handle Joins to create correct SIMD CF structure
630   moveCodeInJoinBlocks();
631   // TODO: currently all SIMD CF is lowered if there is
632   // an unmask construction in module. It is very suboptimal.
633   if (lowerSimdCF)
634     lowerAllSimdCF();
635   else {
636     // Repeatedly check the code for conformance and lower non-conformant gotos
637     // and joins until the code stabilizes.
638     ensureConformance();
639     optimizeRestoredSIMDCF();
640   }
641   // Perform check for genx_simdcf_get_em intrinsics and remove redundant ones.
642   lowerUnsuitableGetEMs();
643   clear();
644 
645   LLVM_DEBUG(dbgs() << "Early SIMD CF Conformance ends\n");
646 
647   return Modified;
648 }
649 
650 /***********************************************************************
651  * runOnFunctionGroup : run the late SIMD control flow conformance pass for this
652  * FunctionGroup
653  */
runOnFunctionGroup(FunctionGroup & ArgFG)654 bool GenXLateSimdCFConformance::runOnFunctionGroup(FunctionGroup &ArgFG)
655 {
656   LLVM_DEBUG(dbgs() << "Late SIMD CF Conformance starts\n");
657 
658   Modified = false;
659   FG = &ArgFG;
660   M = FG->getModule();
661   // Get analyses that we use and/or modify.
662   FGA = &getAnalysis<FunctionGroupAnalysis>();
663   DTWrapper = &getAnalysis<DominatorTreeGroupWrapperPass>();
664   Liveness = &getAnalysis<GenXLiveness>();
665   // Gather the EM values, both from goto/join and phi nodes.
666   gatherEMVals();
667   // Gather the RM values from gotos and phi nodes.
668   gatherRMVals();
669   // Move code in goto and join blocks as necessary.
670   moveCodeInGotoBlocks();
671   moveCodeInJoinBlocks();
672   // Check the code for conformance. In this late pass, we do not expect to
673   // find non-conformance.
674   ensureConformance();
675   // For remaining unlowered gotos and joins (the ones that will become SIMD
676   // control flow instructions), mark the webs of EM and RM values as
677   // category EM or RM respectively. For EM, this also modifies uses as needed.
678   setCategories();
679   clear();
680 
681   LLVM_DEBUG(dbgs() << "Late SIMD CF Conformance ends\n");
682 
683   return Modified;
684 }
685 
686 /***********************************************************************
687  * gatherGotoJoinEMVals : gather the EM values for gotos/joins only
688  *
689  * IncludeIncoming is used for adding goto/join def to EMVals
690  */
gatherGotoJoinEMVals(bool IncludeIncoming)691 void GenXSimdCFConformance::gatherGotoJoinEMVals(bool IncludeIncoming)
692 {
693   // We find gotos and joins by scanning all uses of the intrinsics and (in the
694   // case of the late pass) ignoring ones not in this function group, rather
695   // than scanning the whole IR.
696   Type *I1Ty = Type::getInt1Ty(M->getContext());
697   for (auto IID : { GenXIntrinsic::genx_simdcf_goto, GenXIntrinsic::genx_simdcf_join }) {
698     Type *EMTy = IGCLLVM::FixedVectorType::get(I1Ty, 32);
699     for (unsigned Width = 1; Width <= 32; Width <<= 1) {
700       Type *Tys[] = {EMTy, IGCLLVM::FixedVectorType::get(I1Ty, Width)};
701       auto GotoJoinFunc = GenXIntrinsic::getGenXDeclaration(M, IID, Tys);
702       for (auto ui = GotoJoinFunc->use_begin(), ue = GotoJoinFunc->use_end();
703           ui != ue; ++ui) {
704         auto GotoJoin = dyn_cast<CallInst>(ui->getUser());
705         if (!GotoJoin)
706           continue;
707         if (FG && (FGA->getGroup(GotoJoin->getParent()->getParent()) != FG
708             || ui->getOperandNo() != GotoJoin->getNumArgOperands()))
709           continue;
710         // We have a goto/join (in our function group in the case of the late
711         // pass).  Add the EM value (struct index 0) to EMVals.
712         EMVals.insert(SimpleValue(GotoJoin, 0));
713         // Also add its EM input to EMVals, if not a constant.
714         if (IncludeIncoming && !isa<Constant>(GotoJoin->getOperand(0)))
715           EMVals.insert(SimpleValue(GotoJoin->getOperand(0), 0));
716       }
717     }
718   }
719 }
720 
721 /***********************************************************************
722  * gatherEMVals : gather the EM values, including phi nodes
723  */
gatherEMVals()724 void GenXSimdCFConformance::gatherEMVals()
725 {
726   // Collect gotos/joins and their defs
727   gatherGotoJoinEMVals(true);
728 
729   Type *I1Ty = Type::getInt1Ty(M->getContext());
730   Type *EMTy = IGCLLVM::FixedVectorType::get(I1Ty, 32);
731   Type *Tys[] = { EMTy };
732   auto SavemaskFunc = GenXIntrinsic::getGenXDeclaration(
733       M, GenXIntrinsic::genx_simdcf_savemask, Tys);
734   for (auto ui = SavemaskFunc->use_begin(), ue = SavemaskFunc->use_end(); ui != ue;
735        ++ui) {
736     auto Savemask = dyn_cast<CallInst>(ui->getUser());
737     if (!Savemask)
738       continue;
739     if (FG && (FGA->getGroup(Savemask->getParent()->getParent()) != FG ||
740                ui->getOperandNo() != Savemask->getNumArgOperands()))
741       continue;
742       lowerSimdCF = true;
743     // Add its EM input to EMVals, if not a constant.
744     if (!isa<Constant>(Savemask->getOperand(0)))
745       EMVals.insert(SimpleValue(Savemask->getOperand(0), 0));
746   }
747 
748   auto UnmaskFunc = GenXIntrinsic::getGenXDeclaration(
749       M, GenXIntrinsic::genx_simdcf_unmask, Tys);
750   for (auto ui = UnmaskFunc->use_begin(), ue = UnmaskFunc->use_end(); ui != ue;
751        ++ui) {
752     auto Unmask = dyn_cast<CallInst>(ui->getUser());
753     if (!Unmask)
754       continue;
755     if (FG && (FGA->getGroup(Unmask->getParent()->getParent()) != FG ||
756                ui->getOperandNo() != Unmask->getNumArgOperands()))
757       continue;
758       lowerSimdCF = true;
759     // We have a unmask (in our function group in the case of the late
760     EMVals.insert(SimpleValue(Unmask));
761   }
762   auto RemaskFunc = GenXIntrinsic::getGenXDeclaration(
763       M, GenXIntrinsic::genx_simdcf_remask, Tys);
764   for (auto ui = RemaskFunc->use_begin(), ue = RemaskFunc->use_end(); ui != ue;
765        ++ui) {
766     auto Remask = dyn_cast<CallInst>(ui->getUser());
767     if (!Remask)
768       continue;
769     if (FG && (FGA->getGroup(Remask->getParent()->getParent()) != FG ||
770                ui->getOperandNo() != Remask->getNumArgOperands()))
771       continue;
772       lowerSimdCF = true;
773     // We have a remask (in our function group in the case of the late
774     // pass).  Add the EM value (struct index 0) to EMVals.
775     EMVals.insert(SimpleValue(Remask));
776     // Also add its EM input to EMVals, if not a constant.
777     if (!isa<Constant>(Remask->getOperand(0)))
778       EMVals.insert(SimpleValue(Remask->getOperand(0)));
779   }
780   // delete useless cm_unmask_begin and cm_unmask_end
781   auto UnmaskEF = GenXIntrinsic::getGenXDeclaration(
782          M, GenXIntrinsic::genx_unmask_end);
783   for (auto ui = UnmaskEF->use_begin(), ue = UnmaskEF->use_end(); ui != ue;) {
784     auto u = ui->getUser();
785     ++ui;
786     if (auto UnmaskEnd = dyn_cast<CallInst>(u))
787       UnmaskEnd->eraseFromParent();
788   }
789   auto UnmaskBF = GenXIntrinsic::getGenXDeclaration(
790     M, GenXIntrinsic::genx_unmask_begin);
791   for (auto ui = UnmaskBF->use_begin(), ue = UnmaskBF->use_end(); ui != ue;) {
792     auto u = ui->getUser();
793     ++ui;
794     if (auto UnmaskBeg = dyn_cast<CallInst>(u))
795       UnmaskBeg->eraseFromParent();
796   }
797   // Find related phi nodes and values related by insertvalue/extractvalue/call
798   // using EMVal as a worklist.
799   for (unsigned i = 0; i != EMVals.size(); ++i) {
800     SimpleValue EMVal = EMVals[i];
801     // For this EM value, get the connected values.
802     SmallVector<SimpleValue, 8> ConnectedVals;
803     getConnectedVals(EMVal, RegCategory::EM, /*IncludeOptional=*/true,
804         /*OkJoin=*/nullptr, &ConnectedVals);
805     // Add the connected values to EMVals.
806     for (auto j = ConnectedVals.begin(), je = ConnectedVals.end();
807         j != je; ++j)
808       if (!isa<Constant>(j->getValue()))
809         EMVals.insert(*j);
810   }
811 }
812 
813 /***********************************************************************
814  * gatherRMVals : gather RM values for each join
815  */
gatherRMVals()816 void GenXSimdCFConformance::gatherRMVals()
817 {
818   for (auto ji = EMVals.begin(), je = EMVals.end(); ji != je; ++ji) {
819     auto EMVal = *ji;
820     if (GenXIntrinsic::getGenXIntrinsicID(EMVal.getValue()) != GenXIntrinsic::genx_simdcf_join)
821       continue;
822     auto Join = cast<CallInst>(EMVal.getValue());
823     // We have a join. Gather its web of RM values.
824     auto RMValsEntry = &RMVals[Join];
825     if (!isa<Constant>(Join->getOperand(1)))
826       RMValsEntry->insert(Join->getOperand(1));
827     for (unsigned rvi = 0; rvi != RMValsEntry->size(); ++rvi) {
828       SimpleValue RM = (*RMValsEntry)[rvi];
829       // RM is a value in this join's RM web. Get other values related by phi
830       // nodes and extractvalues and gotos.
831       SmallVector<SimpleValue, 8> ConnectedVals;
832       getConnectedVals(RM, RegCategory::RM, /*IncludeOptional=*/true,
833           Join, &ConnectedVals);
834       for (auto j = ConnectedVals.begin(), je = ConnectedVals.end();
835           j != je; ++j)
836         if (!isa<Constant>(j->getValue()))
837           RMValsEntry->insert(*j);
838     }
839   }
840 }
841 
842 /***********************************************************************
843  * findGotoJoinVal : find goto/join that should be applied at the
844  * specified location
845  *
846  * It uses dominator tree to find the value needed. Category is used to
847  * set proper name for instruction and doesn't affect reg category
848  * that is used in reg alloc. It only shows what we are dealing with.
849  */
findGotoJoinVal(int Cat,BasicBlock * Loc,Instruction * GotoJoinEV,BasicBlockEdge & TrueEdge,BasicBlockEdge & FalseEdge,Value * TrueVal,Value * FalseVal,std::map<BasicBlock *,Value * > & foundVals)850 Value *GenXSimdCFConformance::findGotoJoinVal(int Cat, BasicBlock *Loc, Instruction *GotoJoinEV,
851   BasicBlockEdge &TrueEdge, BasicBlockEdge &FalseEdge, Value *TrueVal, Value *FalseVal, std::map<BasicBlock *, Value *>& foundVals)
852 {
853   IGC_ASSERT(TrueEdge.getStart() == FalseEdge.getStart());
854   IGC_ASSERT(TrueEdge.getEnd() != FalseEdge.getEnd());
855   IGC_ASSERT_MESSAGE((Cat == RegCategory::EM || Cat == RegCategory::PREDICATE), "Handling only EM and Cond!");
856 
857   LLVM_DEBUG(dbgs() << "Entering " << Loc->getName() << "\n");
858 
859   // Check if value was found before
860   auto ResIt = foundVals.find(Loc);
861   if (ResIt != foundVals.end())
862     return ResIt->second;
863 
864   DominatorTree *DomTree = getDomTree(Loc->getParent());
865   if (DomTree->dominates(TrueEdge, Loc)) {
866     LLVM_DEBUG(dbgs() << "Dominated by True Edge\n");
867     foundVals[Loc] = TrueVal;;
868     return TrueVal;
869   }
870   if (DomTree->dominates(FalseEdge, Loc)) {
871     LLVM_DEBUG(dbgs() << "Dominated by False Edge\n");
872     foundVals[Loc] = FalseVal;
873     return FalseVal;
874   }
875 
876   // Need to create phi somewhere.
877   // Try to get IDom. If we found CondEV's BB then we are
878   // already in the final block
879   auto Node = DomTree->getNode(Loc);
880   auto IDom = Node->getIDom();
881   IGC_ASSERT_MESSAGE(IDom, "No IDom found!");
882   BasicBlock *PhiLoc = nullptr;
883   PhiLoc = IDom->getBlock();
884   if (IDom->getBlock() == GotoJoinEV->getParent())
885     PhiLoc = Loc;
886 
887   std::string Name = (Cat == RegCategory::EM) ? "ExecMaskEV" : "CondEV";
888   auto PHI = PHINode::Create(GotoJoinEV->getType(), pred_size(PhiLoc), Name, &PhiLoc->front());
889   foundVals[PhiLoc] = PHI;
890   if (PhiLoc != Loc)
891     foundVals[Loc] = PHI;
892 
893   for (auto pi = pred_begin(PhiLoc), pe = pred_end(PhiLoc); pi != pe; ++pi) {
894     BasicBlock *Pred = *pi;
895     Value *Val = nullptr;
896 
897     // Don't check dominators for def since we are looking for
898     // edges that are located after it
899     if (Pred == TrueEdge.getStart()) {
900       // This happens when we enter def block from join block
901       // w/o any intermediate blocks (actually we expect this
902       // situation to happen always). Check that we came through
903       // true branch.
904       if (Pred->getTerminator()->getSuccessor(0) == PhiLoc) {
905         Val = TrueVal;
906         LLVM_DEBUG(dbgs() << "Usual case\n");
907       } else {
908         // This situation shouldn't happen, but if so, we can handle it
909         Val = FalseVal;
910         LLVM_DEBUG(dbgs() << "Strange case\n");
911       }
912     } else {
913       Val = findGotoJoinVal(Cat, Pred, GotoJoinEV, TrueEdge, FalseEdge, TrueVal, FalseVal, foundVals);
914     }
915 
916     PHI->addIncoming(Val, Pred);
917   }
918 
919   LLVM_DEBUG(dbgs() << "Built PHI for EV:" << *PHI << "\n");
920   return PHI;
921 }
922 
923 /**
924  * collectCondEVUsers : gather Cond EV users
925  *
926  * Bad users: they should not use cond EV.
927  * Correct user: conditional branch CondEV's BB. This
928  * is the only possible conformant user.
929  */
collectCondEVUsers(ExtractValueInst * CondEV,std::vector<Value * > & BadUsers,BranchInst * & CorrectUser)930 void GenXSimdCFConformance::collectCondEVUsers(ExtractValueInst *CondEV, std::vector<Value *> &BadUsers, BranchInst *&CorrectUser)
931 {
932   // Bad users: they should not use cond EV. Make a real value for them
933   // Correct user: conditional branch in this BB
934   for (auto ui = CondEV->use_begin(), ue = CondEV->use_end(); ui != ue; ++ui) {
935     BranchInst *Br = dyn_cast<BranchInst>(ui->getUser());
936 
937     // If cond EV is used by wrong branch, we can simply consider
938     // it as non-baled conditional branch
939     if (!Br || Br->getParent() != CondEV->getParent()) {
940       LLVM_DEBUG(dbgs() << "Found bad CondEV user:\n" << *ui->getUser() << "\n");
941       BadUsers.push_back(ui->getUser());
942     } else if (Br) {
943       IGC_ASSERT_MESSAGE(!CorrectUser, "Found another correct user!");
944       LLVM_DEBUG(dbgs() << "Found correct user:\n" << *Br << "\n");
945       CorrectUser = Br;
946     }
947   }
948 }
949 
950 /**
951  * updateBadCondEVUsers : update bad cond EV users
952  *
953  * It replaces cond EV uses by values that can be
954  * obtained on true and false pathes
955  */
updateBadCondEVUsers(GenXSimdCFConformance::GotoJoinEVs & GotoJoinData,std::vector<Value * > & BadUsers,BasicBlock * TrueSucc,BasicBlock * FalseSucc)956 void GenXSimdCFConformance::updateBadCondEVUsers(GenXSimdCFConformance::GotoJoinEVs &GotoJoinData,
957   std::vector<Value *> &BadUsers, BasicBlock *TrueSucc, BasicBlock *FalseSucc)
958 {
959   ExtractValueInst *CondEV = GotoJoinData.getCondEV();
960   IGC_ASSERT_MESSAGE(CondEV, "Expected valid CondEV!");
961 
962   BasicBlockEdge TrueEdge(CondEV->getParent(), TrueSucc);
963   BasicBlockEdge FalseEdge(CondEV->getParent(), FalseSucc);
964   Constant *TrueVal = Constant::getAllOnesValue(CondEV->getType());
965   Constant *FalseVal = Constant::getNullValue(CondEV->getType());
966 
967   // Update users
968   std::map<BasicBlock *, Value *> FoundCondEV;
969   for (auto bi = BadUsers.begin(), be = BadUsers.end(); bi != be; ++bi) {
970     Instruction *User = cast<Instruction>(*bi);
971     for (unsigned idx = 0, opNum = User->getNumOperands(); idx < opNum; ++idx) {
972       if (CondEV != User->getOperand(idx))
973         continue;
974 
975       User->setOperand(idx, findGotoJoinVal(RegCategory::PREDICATE, User->getParent(), CondEV, TrueEdge, FalseEdge, TrueVal, FalseVal, FoundCondEV));
976     }
977   }
978 }
979 
980 /**
981  * addNewPhisIncomings : add new incomings after split
982  *
983  * It is needed to update phis after turning unconditional
984  * branch into conditional one. True successor is assumed to
985  * be correct join point, but the only thing we know here
986  * is that FalseSucc branches to TrueSucc. Branching Block's
987  * successors are TrueSucc and FalseSucc.
988  */
addNewPhisIncomings(BasicBlock * BranchingBlock,BasicBlock * TrueSucc,BasicBlock * FalseSucc)989 void GenXSimdCFConformance::addNewPhisIncomings(BasicBlock *BranchingBlock, BasicBlock *TrueSucc, BasicBlock *FalseSucc)
990 {
991   for (auto Inst = &TrueSucc->front();
992     auto PN = dyn_cast<PHINode>(Inst);
993     Inst = Inst->getNextNode()) {
994     Value* CurrVal = PN->getIncomingValueForBlock(BranchingBlock);
995     PN->addIncoming(CurrVal, FalseSucc);
996   }
997 }
998 
999 /**
1000  * handleNoCondEVCase : handle case when there is no
1001  * CondEV for goto/join.
1002  *
1003  * It performs split for goto in order to prepare
1004  * goto for possible EM lower. Goto is branch itself
1005  * so such transformation doesn't introduce any
1006  * overhead in case of conformant SIMD CF.
1007  *
1008  * TODO: this transformation can be reverted in case of
1009  * non-conformant SIMD CF if necessary data was saved.
1010  * It is not done now because no non-conformant cases
1011  * were found so far.
1012  */
handleNoCondEVCase(GenXSimdCFConformance::GotoJoinEVs & GotoJoinData)1013 void GenXSimdCFConformance::handleNoCondEVCase(GenXSimdCFConformance::GotoJoinEVs &GotoJoinData)
1014 {
1015   IGC_ASSERT_MESSAGE(!GotoJoinData.getCondEV(), "Unexpected CondEV!");
1016 
1017   // Handle only goto
1018   if (GotoJoinData.isJoin())
1019     return;
1020   auto SplitPoint = GotoJoinData.getSplitPoint();
1021 
1022   // Skip possible goto users
1023   for (;; SplitPoint = SplitPoint->getNextNode()) {
1024     if (SplitPoint->isTerminator())
1025       break;
1026     if (auto CI = dyn_cast<CallInst>(SplitPoint)) {
1027       // We need to perform split before next goto/join to save their conformance
1028       if (GenXIntrinsic::getGenXIntrinsicID(CI) == GenXIntrinsic::genx_simdcf_goto ||
1029         GenXIntrinsic::getGenXIntrinsicID(CI) == GenXIntrinsic::genx_simdcf_join)
1030         break;
1031     }
1032   }
1033 
1034   Value *GotoJoin = GotoJoinData.getGotoJoin();
1035   ExtractValueInst *CondEV = ExtractValueInst::Create(GotoJoin, { 2 }, "missing_extractcond", SplitPoint);
1036   GotoJoinData.setCondEV(CondEV);
1037 
1038   if (auto Br = dyn_cast<BranchInst>(SplitPoint)) {
1039     if (Br->isConditional()) {
1040       // This CF is non-conformant: there should be a join point
1041       // before this branch, but it wasn't found. Skip it.
1042       return;
1043     }
1044     // We are turning unconditional branch into conditional one
1045     BasicBlock *Split = BasicBlock::Create(CondEV->getContext(), "goto_split", CondEV->getParent()->getParent(), Br->getSuccessor(0));
1046     BranchInst::Create(Br->getSuccessor(0), Split);
1047     BranchInst::Create(Br->getSuccessor(0), Split, CondEV, Br);
1048 
1049     // Update phis in TrueSucc
1050     addNewPhisIncomings(CondEV->getParent(), Br->getSuccessor(0), Split);
1051 
1052     Br->eraseFromParent();
1053   } else {
1054     // Split point is in the middle of BB. We assume that there is a join point
1055     // after it.
1056     // TODO: consider adding this check. No such cases were found now.
1057     BasicBlock *TrueSucc = CondEV->getParent()->splitBasicBlock(SplitPoint, "cond_ev_true_split");
1058     CondEV->getParent()->getTerminator()->eraseFromParent();
1059     LLVM_DEBUG(dbgs() << "Created " << TrueSucc->getName() << " to handle missing conditional branch\n");
1060 
1061     // False block: need to create new one
1062     BasicBlock *FalseSucc = BasicBlock::Create(CondEV->getContext(), "cond_ev_false_split", CondEV->getParent()->getParent(),
1063       TrueSucc);
1064     LLVM_DEBUG(dbgs() << "Created " << FalseSucc->getName() << " to handle missing conditional branch\n");
1065 
1066     // Link blocks
1067     BranchInst::Create(TrueSucc, FalseSucc, CondEV, CondEV->getParent());
1068     BranchInst::Create(TrueSucc, FalseSucc);
1069   }
1070 
1071   // CFG changed: update DomTree.
1072   // TODO: there must be workaround to do it in a more optimal way
1073   DominatorTree *domTree = getDomTree(CondEV->getParent()->getParent());
1074   domTree->recalculate(*CondEV->getParent()->getParent());
1075 }
1076 
1077 /**
1078  * handleOptimizedBranchCase : perform split for optimized branch case
1079  *
1080  * TODO: this make sence only in case when the true successor is a
1081  * join block, otherwise it will introduce more overhead due to
1082  * goto/join lowering. Also there should be check that this
1083  * join really uses current EM and RM. This issue is resolved
1084  * at the end of this pass in EM/RM liveness analysis and cannot
1085  * be done easy at this point. For now assume that everything OK
1086  * with it here.
1087  *
1088  * TODO: It is possible to undo this transformation if we store
1089  * all necessery data here. Currently it is not done:
1090  * no non-conformant cases found for now.
1091  *
1092  * Due to earlier transformations we can split BB after the last
1093  * goto/join EV. It will solve issue with join located in this
1094  * basic block. Code movements to sink goto/join will be performed
1095  * further, we don't need to focus on it here.
1096  */
handleOptimizedBranchCase(GenXSimdCFConformance::GotoJoinEVs & GotoJoinData,BasicBlock * & TrueSucc,BasicBlock * & FalseSucc)1097 void GenXSimdCFConformance::handleOptimizedBranchCase(GenXSimdCFConformance::GotoJoinEVs &GotoJoinData, BasicBlock *&TrueSucc, BasicBlock *&FalseSucc)
1098 {
1099   // Look for the first non-goto/join user inst
1100   auto SplitPoint = GotoJoinData.getSplitPoint();
1101 
1102   ExtractValueInst *CondEV = GotoJoinData.getCondEV();
1103   IGC_ASSERT_MESSAGE(CondEV, "Expected valid CondEV!");
1104 
1105   // Split: this is true succ which is join point (at least we assume that)
1106   TrueSucc = CondEV->getParent()->splitBasicBlock(SplitPoint, "cond_ev_true_split");
1107   LLVM_DEBUG(dbgs() << "Created " << TrueSucc->getName() << " to handle missing conditional branch\n");
1108   CondEV->getParent()->getTerminator()->eraseFromParent();
1109   // False block: need to create new one
1110   FalseSucc = BasicBlock::Create(CondEV->getContext(), "cond_ev_false_split", CondEV->getParent()->getParent(),
1111     TrueSucc);
1112   LLVM_DEBUG(dbgs() << "Created " << FalseSucc->getName() << " to handle missing conditional branch\n");
1113   // Link blocks
1114   BranchInst::Create(TrueSucc, FalseSucc, CondEV, CondEV->getParent());
1115   BranchInst::Create(TrueSucc, FalseSucc);
1116 
1117   // Store info for possible optimization
1118   BlocksToOptimize[TrueSucc] =
1119       JoinPointOptData(FalseSucc, GotoJoinData.getEMEV());
1120 
1121   // CFG changed: update DomTree.
1122   // TODO: there must be workaround to do it in a more optimal way
1123   DominatorTree *domTree = getDomTree(CondEV->getParent()->getParent());
1124   domTree->recalculate(*CondEV->getParent()->getParent());
1125 }
1126 
1127 /**
1128  * handleExistingBranchCase : perform actions needed to
1129  * handle case when branch wasn't optimized
1130  *
1131  * It stores True/False successors and adds new BB
1132  * in case when both successors are the same BB.
1133  */
handleExistingBranchCase(GenXSimdCFConformance::GotoJoinEVs & GotoJoinData,BasicBlock * & TrueSucc,BasicBlock * & FalseSucc,BranchInst * ExistingBranch)1134 void GenXSimdCFConformance::handleExistingBranchCase(GenXSimdCFConformance::GotoJoinEVs &GotoJoinData,
1135   BasicBlock *&TrueSucc, BasicBlock *&FalseSucc, BranchInst *ExistingBranch)
1136 {
1137   ExtractValueInst *CondEV = GotoJoinData.getCondEV();
1138   IGC_ASSERT_MESSAGE(CondEV, "Expected valid CondEV!");
1139   IGC_ASSERT_MESSAGE(ExistingBranch->isConditional(), "Expected conditional branch!");
1140 
1141   TrueSucc = ExistingBranch->getSuccessor(0);
1142   FalseSucc = ExistingBranch->getSuccessor(1);
1143 
1144   if (TrueSucc == FalseSucc) {
1145     // We need to simply introduce new BB to get CondEV
1146     FalseSucc = BasicBlock::Create(CondEV->getContext(), "cond_ev_split", CondEV->getParent()->getParent(),
1147       TrueSucc);
1148     BranchInst::Create(TrueSucc, FalseSucc);
1149     ExistingBranch->setSuccessor(1, FalseSucc);
1150 
1151     LLVM_DEBUG(dbgs() << "Created " << FalseSucc->getName() << " to handle always taken CONDITIONAL branch\n");
1152 
1153     // Update phis in TrueSucc
1154     addNewPhisIncomings(CondEV->getParent(), TrueSucc, FalseSucc);
1155 
1156     // CFG changed: update DomTree.
1157     // TODO: there must be workaround to do it in a more optimal way
1158     DominatorTree *domTree = getDomTree(CondEV->getParent()->getParent());
1159     domTree->recalculate(*CondEV->getParent()->getParent());
1160   }
1161 }
1162 
1163 /**
1164  * handleCondValue : perform analysis on Cond EV usage and fix
1165  * it if needed
1166  *
1167  * The basic use case is optimized False Successor. That
1168  * often happens in standard SimplifyCFG pass.
1169  */
handleCondValue(Value * GotoJoin)1170 void GenXSimdCFConformance::handleCondValue(Value *GotoJoin)
1171 {
1172   GotoJoinEVs &GotoJoinData = GotoJoinEVsMap[GotoJoin];
1173   ExtractValueInst *CondEV = GotoJoinData.getCondEV();
1174 
1175   // No cond EV - nothing to handle. Here we create branch for goto
1176   // to make it easier to handle possible bad EM users. Goto is a
1177   // branch itself and it won't introduce any overhead in case
1178   // of conformant SIMD CF
1179   if (!CondEV) {
1180     handleNoCondEVCase(GotoJoinData);
1181     return;
1182   }
1183 
1184   // Collect Cond EV users
1185   std::vector<Value *> BadUsers;
1186   BranchInst *CorrectUser = nullptr;
1187   collectCondEVUsers(CondEV, BadUsers, CorrectUser);
1188 
1189   // Nothing needs to be fixed. However, allow this algorithm to fix
1190   // case with TrueSucc == FalseSucc for goto in order to simplify further
1191   // analysis.
1192   if (BadUsers.empty() && GotoJoinData.isJoin())
1193     return;
1194 
1195   BasicBlock *TrueSucc = nullptr;
1196   BasicBlock *FalseSucc = nullptr;
1197 
1198   if (!CorrectUser) {
1199     // Branch was optimized by some pass. We need to create it again.
1200     handleOptimizedBranchCase(GotoJoinData, TrueSucc, FalseSucc);
1201   } else {
1202     // Branch is still here. Perform actions needed.
1203     handleExistingBranchCase(GotoJoinData, TrueSucc, FalseSucc, CorrectUser);
1204   }
1205 
1206   // Update users
1207   updateBadCondEVUsers(GotoJoinData, BadUsers, TrueSucc, FalseSucc);
1208 }
1209 
1210 /***********************************************************************
1211  * splitGotoJoinBlocks : split Basic Blocks that contains goto/join
1212  *
1213  * This is used to solve problems that can be introduced by some
1214  * standard LLVM passes: one of them is simplified CFG that lead to
1215  * goto/join's condition usage by non-branch instruction. After this
1216  * transformation each BB will contain only one goto or join instruction
1217  * (or none of them), that fact allows us to make further changes simplier.
1218  */
splitGotoJoinBlocks()1219 void GenXSimdCFConformance::splitGotoJoinBlocks() {
1220 
1221   LLVM_DEBUG(dbgs() << "Splitting GotoJoin Blocks\n");
1222 
1223   for (auto &Elem : GotoJoinEVsMap) {
1224 
1225     Value *GotoJoin = Elem.first;
1226     auto &GotoJoinData = Elem.second;
1227 
1228     LLVM_DEBUG(dbgs() << "Trying to split BB for:\n" << *GotoJoin << "\n");
1229 
1230     handleCondValue(GotoJoin);
1231 
1232     if (GotoJoinData.isJoin()) {
1233       auto SplitPoint = GotoJoinData.getSplitPoint();
1234       if (SplitPoint->isTerminator())
1235         continue;
1236       SplitPoint->getParent()->splitBasicBlock(SplitPoint, "split_for_join");
1237       // CFG changed: update DomTree.
1238       // TODO: there must be workaround to do it in a more optimal way
1239       DominatorTree *domTree = getDomTree(SplitPoint->getParent()->getParent());
1240       domTree->recalculate(*SplitPoint->getParent()->getParent());
1241     }
1242   }
1243 
1244   LLVM_DEBUG(dbgs() << "Done splitting\n\n" << *M << "\n\n");
1245 }
1246 
1247 /***********************************************************************
1248  * removeFromEMRMVals : remove a value from EMVals or RMVals
1249  *
1250  * This is used just before erasing a phi node in moveCodeInJoinBlocks.
1251  */
removeFromEMRMVals(Value * V)1252 void GenXSimdCFConformance::removeFromEMRMVals(Value *V)
1253 {
1254   auto VT = dyn_cast<VectorType>(V->getType());
1255   if (!VT || !VT->getElementType()->isIntegerTy(1))
1256     return;
1257   if (EMVals.remove(SimpleValue(V, 0)))
1258     return;
1259   for (auto i = RMVals.begin(), e = RMVals.end(); i != e; ++i) {
1260     auto RMValsEntry = &i->second;
1261     if (RMValsEntry->remove(SimpleValue(V, 0)))
1262       return;
1263   }
1264 }
1265 
1266 /***********************************************************************
1267  * hoistGotoUser : hoist instruction that uses goto's EV and is located
1268  * after it in the same basic block.
1269  *
1270  * Since goto must be at the end of basic block, we have to solve
1271  * this problem somehow. Current approach is to duplicate instruction
1272  * on both paths (true and false) and update uses.
1273  *
1274  * It is always possible to perform such transformation even if there
1275  * is a chain of users: we just can duplicate them all. Since we know
1276  * all values on the true pass, it should be possible to perform full
1277  * calculation in this case. However, it is not done now because it can
1278  * lead to much worse code when SIMD CF is not conformant (we are not
1279  * sure that it is conformant at this point).
1280  */
hoistGotoUser(Instruction * Inst,CallInst * Goto,unsigned operandNo)1281 bool GenXSimdCFConformance::hoistGotoUser(Instruction *Inst, CallInst *Goto, unsigned operandNo)
1282 {
1283   // Find branch for goto
1284   ExtractValueInst *CondEV = GotoJoinEVsMap[Goto].getCondEV();
1285   auto BrIt = std::find_if(CondEV->use_begin(), CondEV->use_end(),
1286     [&Goto](const Use& u) {
1287       auto Br = dyn_cast<BranchInst>(u.getUser());
1288       return (Br && Br->getParent() == Goto->getParent() && Br->isConditional());
1289     });
1290   IGC_ASSERT_MESSAGE(BrIt != CondEV->use_end(), "All gotos should become branching earlier!");
1291 
1292   BranchInst *Br = cast<BranchInst>(BrIt->getUser());
1293   BasicBlock *TrueSucc = Br->getSuccessor(0);
1294   BasicBlock *FalseSucc = Br->getSuccessor(1);
1295 
1296   // Handle FallThrough block with phis.
1297   //
1298   // TODO: it is redundant in some cases. For example, there can be Phi that
1299   // uses bitcasts from EM from two paths. In this case we can use one
1300   // GetEM from Phi with EM. Currently there is no trivial mechanism
1301   // to check for that because in this case Phi arguments are supposed to use
1302   // different Exectution Masks according to DF.
1303   //
1304   // Temporary solution for that is to place a splitter block that branches to
1305   // such bb directly. Examples of that case can be found in local-atomics
1306   // tests in ISPC.
1307   if (isa<PHINode>(&FalseSucc->front())) {
1308     BasicBlock *Splitter = BasicBlock::Create(FalseSucc->getContext(), "phi_fallthrough_splitter", FalseSucc->getParent());
1309     Splitter->moveAfter(Goto->getParent());
1310     BranchInst::Create(FalseSucc, Splitter);
1311     Br->setSuccessor(1, Splitter);
1312     // Update phis
1313     for (auto CurrInst = &FalseSucc->front();
1314          auto PN = dyn_cast<PHINode>(CurrInst);
1315          CurrInst = CurrInst->getNextNode()) {
1316       for (unsigned idx = 0, num = PN->getNumIncomingValues(); idx < num; ++idx) {
1317         if (PN->getIncomingBlock(idx) == Goto->getParent())
1318           PN->setIncomingBlock(idx, Splitter);
1319       }
1320     }
1321     FalseSucc = Splitter;
1322     // CFG changed: update DomTree.
1323     // TODO: there must be workaround to do it in a more optimal way
1324     DominatorTree *domTree = getDomTree(CondEV->getParent()->getParent());
1325     domTree->recalculate(*CondEV->getParent()->getParent());
1326   }
1327 
1328   // Copy instruction and set the value for true block. Place it before goto.
1329   Instruction *TrueVal = Inst->clone();
1330   TrueVal->insertBefore(Goto);
1331   TrueVal->setOperand(operandNo, Constant::getNullValue(Inst->getOperand(operandNo)->getType()));
1332 
1333   // Copy instruction and place it in the false successor. Get EM will be
1334   // created later to handle its goto use.
1335   Instruction *FalseVal = Inst->clone();
1336   FalseVal->insertBefore(FalseSucc->getFirstNonPHI());
1337 
1338   // Handle all users
1339   BasicBlockEdge TrueEdge(Goto->getParent(), TrueSucc);
1340   BasicBlockEdge FalseEdge(Goto->getParent(), FalseSucc);
1341   std::map<BasicBlock *, Value *> foundVals;
1342   std::vector<Value *> newOperands;
1343   for (auto ui = Inst->use_begin(), ue = Inst->use_end(); ui != ue; ++ui) {
1344     auto User = dyn_cast<Instruction>(ui->getUser());
1345     // TODO: it can be solved with duplicated instructions.
1346     // Currently we are not going to duplicate them.
1347     if (User->getParent() == Inst->getParent()) {
1348       TrueVal->eraseFromParent();
1349       FalseVal->eraseFromParent();
1350       return false;
1351     }
1352 
1353     BasicBlock *Loc = User->getParent();
1354     if (auto PN = dyn_cast<PHINode>(User))
1355       Loc = PN->getIncomingBlock(ui->getOperandNo());
1356 
1357     // Store new value
1358     Value *NewOperand = nullptr;
1359     if (Loc == Goto->getParent())
1360       NewOperand = TrueVal;
1361     else
1362       NewOperand = findGotoJoinVal(RegCategory::EM, Loc, Inst, TrueEdge, FalseEdge,
1363         TrueVal, FalseVal, foundVals);
1364 
1365     newOperands.push_back(NewOperand);
1366   }
1367 
1368   // Update uses
1369   unsigned i = 0;
1370   for (auto ui = Inst->use_begin(), ue = Inst->use_end(); ui != ue;) {
1371     auto User = dyn_cast<Instruction>(ui->getUser());
1372     unsigned opNo = ui->getOperandNo();
1373     ++ui;
1374     User->setOperand(opNo, newOperands[i++]);
1375   }
1376 
1377   return true;
1378 }
1379 
1380 /***********************************************************************
1381  * moveCodeInGotoBlocks : move code in goto blocks
1382  *
1383  * A goto and its extractvalues must be at the end of the block. (Actually, if
1384  * the !any result of the goto is used in a conditional branch at the end of
1385  * the block, then the goto being baled into the branch means that it is
1386  * treated as being at the end of the block anyway. The only reason we need to
1387  * sink it here is to ensure that isGotoBlock works.)
1388  *
1389  * This can silently fail to sink a goto, in which case checkGoto will spot that
1390  * the goto is not conformant.
1391  */
moveCodeInGotoBlocks(bool hoistGotoUsers)1392 void GenXSimdCFConformance::moveCodeInGotoBlocks(bool hoistGotoUsers)
1393 {
1394   for (auto gi = EMVals.begin(), ge = EMVals.end(); gi != ge; ++gi) {
1395     auto EMVal = *gi;
1396     if (GenXIntrinsic::getGenXIntrinsicID(EMVal.getValue()) != GenXIntrinsic::genx_simdcf_goto)
1397       continue;
1398     auto Goto = cast<CallInst>(EMVal.getValue());
1399     // We want to sink the goto and its extracts. In fact we hoist any other
1400     // instruction, checking that it does not use the extracts.
1401     // With hoistGotoUsers, we are trying to hoist them, too.
1402     // We are skipping all instructions that use skipped instructions
1403     // in order to save dominance.
1404     std::set<Instruction *> Skipping;
1405     for (Instruction *NextInst = Goto->getNextNode();;) {
1406       auto Inst = NextInst;
1407       if (Inst->isTerminator())
1408         break;
1409       IGC_ASSERT(Inst);
1410       NextInst = Inst->getNextNode();
1411       if (auto Extract = dyn_cast<ExtractValueInst>(Inst))
1412         if (Extract->getOperand(0) == Goto)
1413           continue;
1414       bool Failed = false;
1415       for (unsigned oi = 0, oe = Inst->getNumOperands(); oi != oe; ++oi) {
1416         if (auto I = dyn_cast<Instruction>(Inst->getOperand(oi)))
1417           if (Skipping.count(I)) {
1418             LLVM_DEBUG(dbgs() << "Skipping " << Inst->getName() << " due to use of skipped inst\n");
1419             Skipping.insert(Inst);
1420             Failed = true;
1421             break;
1422           }
1423         if (auto Extract = dyn_cast<ExtractValueInst>(Inst->getOperand(oi)))
1424           if (Extract->getOperand(0) == Goto) {
1425             // This is used after splitting basic blocks.
1426             // To perform this all gotos must be branching since EM
1427             // is changed by goto.
1428             if (hoistGotoUsers && hoistGotoUser(Inst, Goto, oi)) {
1429               continue;
1430             }
1431             LLVM_DEBUG(dbgs() << "moveCodeInGotoBlocks: " << Goto->getName() << " failed\n");
1432             LLVM_DEBUG(dbgs() << "Could not hoist " << Inst->getName() << "\n");
1433             Failed = true;
1434             Skipping.insert(Inst);
1435             break; // Intervening instruction uses extract of goto; abandon
1436           }
1437       }
1438       if (Failed)
1439         continue;
1440       // Hoist the instruction.
1441       Inst->removeFromParent();
1442       Inst->insertBefore(Goto);
1443     }
1444   }
1445 }
1446 
1447 /***********************************************************************
1448  * moveCodeInJoinBlocks : move code in join blocks as necessary
1449  *
1450  * 1. For a join label block (a block that is the JIP of other gotos/joins), a
1451  *    join must come at the start of the block.
1452  *
1453  * 2. For a branching join block (one whose conditional branch condition is the
1454  *    !any result from a join), the join must be at the end of the block.
1455  *
1456  * 3. For a block that has one join with both of the above true, we need to move
1457  *    all other code out of the block.
1458  *
1459  * We achieve this as follows:
1460  *
1461  * a. First handle case 3. For any such block, hoist any other code to the end
1462  *    of its immediate dominator. To allow for the immediate dominator also
1463  *    being a case 3 join, we process blocks in post-order depth first search
1464  *    order, so we visit a block before its dominator. Thus code from a case 3
1465  *    join block eventually gets moved up to its closest dominating block that
1466  *    is not a case 3 join block.
1467  *
1468  *    Because it is more convenient and does not hurt, we also hoist the code
1469  *    before the first join in a block that initially looks like it is case 3,
1470  *    even if it then turns out not to be a case 3 join because it has multiple
1471  *    joins.
1472  *
1473  * b. Then scan all joins handling case 1.
1474  *
1475  * c. No need to handle case 2 here, as it (together with a similar requirement
1476  *    to sink a goto in a branching goto block) is checked in checkConformance
1477  *    and treated as sunk subsequently by virtue of getting baled in to the
1478  *    branch.
1479  *
1480  * This happens in both SIMD CF conformance passes, in case constant loading
1481  * etc sneaks code back into the wrong place in a join block. Any pass after
1482  * the late SIMD CF conformance pass needs to be careful not to sneak code back
1483  * into a join block.
1484  *
1485  * Any failure to do the above is not flagged here, but it will be spotted when
1486  * checking the join for conformance.
1487  *
1488  * moveCodeInGotoBlocks needs to run first, as we rely on its sinking of an
1489  * unconditional branch goto for isBranchingGotoJoinBlock to work.
1490  */
moveCodeInJoinBlocks()1491 void GenXSimdCFConformance::moveCodeInJoinBlocks()
1492 {
1493   // a. Handle case 3 join blocks.
1494   if (!FG) {
1495     // Early pass: iterate all funcs in the module.
1496     for (auto mi = M->begin(), me = M->end(); mi != me; ++mi) {
1497       Function *F = &*mi;
1498       if (!F->empty())
1499         emptyBranchingJoinBlocksInFunc(F);
1500     }
1501   } else {
1502     // Late pass: iterate all funcs in the function group.
1503     for (auto fgi = FG->begin(), fge = FG->end(); fgi != fge; ++fgi) {
1504       Function *F = *fgi;
1505       emptyBranchingJoinBlocksInFunc(F);
1506     }
1507   }
1508   // b. Process all other joins (in fact all joins, but ones successfully
1509   // processed above will not need anything doing).
1510   // Get the joins into a vector first, because the code below modifies EMVals.
1511   SmallVector<CallInst *, 4> Joins;
1512   for (auto ji = EMVals.begin(), je = EMVals.end(); ji != je; ++ji) {
1513     auto EMVal = *ji;
1514     if (GenXIntrinsic::getGenXIntrinsicID(EMVal.getValue()) != GenXIntrinsic::genx_simdcf_join)
1515       continue;
1516     Joins.push_back(cast<CallInst>(EMVal.getValue()));
1517   }
1518   for (auto ji = Joins.begin(), je = Joins.end(); ji != je; ++ji) {
1519     auto Join = *ji;
1520     auto JoinBlock = Join->getParent();
1521     if (GotoJoin::isJoinLabel(JoinBlock, /*SkipCriticalEdgeSplitter=*/true))
1522       hoistJoin(Join);
1523     else {
1524       // The join is in a block that is not a join label. Also check the case
1525       // that there is a predecessor that:
1526       // 1. has one successor; and
1527       // 2. is empty other than phi nodes; and
1528       // 3. is a join label.
1529       // In that case we merge the two blocks, merging phi nodes.
1530       // I have seen this situation arise where LLVM decides to add a loop
1531       // pre-header block.
1532       BasicBlock *PredBlock = nullptr;
1533       for (auto ui = JoinBlock->use_begin(), ue = JoinBlock->use_end(); ui != ue; ++ui) {
1534         auto Br = dyn_cast<BranchInst>(ui->getUser());
1535         if (!Br || Br->isConditional())
1536           continue;
1537         auto BB = Br->getParent();
1538         if (BB->getFirstNonPHIOrDbg() != Br)
1539           continue;
1540         if (GotoJoin::isJoinLabel(BB, /*SkipCriticalEdgeSplitter=*/true)) {
1541           PredBlock = BB;
1542           break;
1543         }
1544       }
1545       if (PredBlock) {
1546         // We have such a predecessor block. First hoist the join in our block.
1547         if (hoistJoin(Join)) {
1548           // Join hoisting succeeded. Now merge the blocks.
1549           LLVM_DEBUG(dbgs() << "moveCodeInJoinBlocks: merging " << PredBlock->getName()
1550               << " into " << JoinBlock->getName() << "\n");
1551           // First adjust the phi nodes to include both blocks' incomings.
1552           for (auto Phi = dyn_cast<PHINode>(&JoinBlock->front()); Phi;
1553               Phi = dyn_cast<PHINode>(Phi->getNextNode())) {
1554             int Idx = Phi->getBasicBlockIndex(PredBlock);
1555             if (Idx >= 0) {
1556               Value *Incoming = Phi->getIncomingValue(Idx);
1557               auto PredPhi = dyn_cast<PHINode>(Incoming);
1558               if (PredPhi && PredPhi->getParent() != PredBlock)
1559                 PredPhi = nullptr;
1560               if (PredPhi) {
1561                 // The incoming in JoinBlock is a phi node in PredBlock. Add its
1562                 // incomings.
1563                 Phi->removeIncomingValue(Idx, /*DeletePHIIfEmpty=*/false);
1564                 for (unsigned oi = 0, oe = PredPhi->getNumIncomingValues();
1565                     oi != oe; ++oi)
1566                   Phi->addIncoming(PredPhi->getIncomingValue(oi),
1567                       PredPhi->getIncomingBlock(oi));
1568               } else {
1569                 // Otherwise, add the predecessors of PredBlock to the phi node
1570                 // in JoinBlock.
1571                 for (auto ui2 = PredBlock->use_begin(),
1572                     ue2 = PredBlock->use_end(); ui2 != ue2; ++ui2) {
1573                   Instruction *Term = dyn_cast<Instruction>(ui2->getUser());
1574                   IGC_ASSERT(Term);
1575                   if (Term->isTerminator()) {
1576                     auto PredPred = Term->getParent();
1577                     if (Idx >= 0) {
1578                       Phi->setIncomingBlock(Idx, PredPred);
1579                       Idx = -1;
1580                     } else
1581                       Phi->addIncoming(Incoming, PredPred);
1582                   }
1583                 }
1584               }
1585             }
1586           }
1587           // Any phi in PredBlock that is not used in a phi in JoinBlock (and
1588           // so still has at least one use after the code above) needs to be
1589           // moved to JoinBlock, with itself added as the extra incomings. The
1590           // incoming blocks to JoinBlock other than PredBlock must be loop
1591           // back edges.
1592           for (;;) {
1593             auto Phi = dyn_cast<PHINode>(&PredBlock->front());
1594             if (!Phi)
1595               break;
1596             if (Phi->use_empty()) {
1597               removeFromEMRMVals(Phi);
1598               Phi->eraseFromParent();
1599               continue;
1600             }
1601             for (auto ui = JoinBlock->use_begin(), ue = JoinBlock->use_end();
1602                 ui != ue; ++ui) {
1603               auto Term = dyn_cast<Instruction>(ui->getUser());
1604               IGC_ASSERT(Term);
1605               if (!Term->isTerminator())
1606                 continue;
1607               auto TermBB = Term->getParent();
1608               if (TermBB == PredBlock)
1609                 continue;
1610               Phi->addIncoming(Phi, TermBB);
1611             }
1612             Phi->removeFromParent();
1613             Phi->insertBefore(&JoinBlock->front());
1614           }
1615           // Adjust branches targeting PredBlock to target JoinBlock instead.
1616           PredBlock->replaceAllUsesWith(JoinBlock);
1617           // Remove PredBlock.
1618           PredBlock->eraseFromParent();
1619         }
1620       }
1621     }
1622   }
1623 }
1624 
1625 /***********************************************************************
1626  * emptyBranchingJoinBlocksInFunc : empty other instructions out of each
1627  *    block in a function that is both a join label and a branching join block
1628  *
1629  * See comment for moveCodeInJoinBlocks above.
1630  */
emptyBranchingJoinBlocksInFunc(Function * F)1631 void GenXSimdCFConformance::emptyBranchingJoinBlocksInFunc(Function *F)
1632 {
1633   for (auto i = po_begin(&F->getEntryBlock()), e = po_end(&F->getEntryBlock());
1634       i != e; ++i) {
1635     BasicBlock *BB = *i;
1636     CallInst *Join = GotoJoin::isBranchingJoinBlock(BB);
1637     if (!Join)
1638       continue;
1639     emptyBranchingJoinBlock(Join);
1640   }
1641 }
1642 
1643 /***********************************************************************
1644  * emptyBranchingJoinBlock : empty instructions other than the join (and its
1645  *      extracts) from this branching join block
1646  */
emptyBranchingJoinBlock(CallInst * Join)1647 void GenXSimdCFConformance::emptyBranchingJoinBlock(CallInst *Join)
1648 {
1649   BasicBlock *BB = Join->getParent();
1650   Instruction *InsertBefore = nullptr;
1651   for (Instruction *NextInst = BB->getFirstNonPHIOrDbg();;) {
1652     auto Inst = NextInst;
1653     if (Inst->isTerminator())
1654       break;
1655     NextInst = Inst->getNextNode();
1656     if (Inst == Join)
1657       continue; // do not hoist the join itself
1658     if (GenXIntrinsic::getGenXIntrinsicID(Inst) == GenXIntrinsic::genx_simdcf_join)
1659       break; // we have encountered another join; there must be more than one
1660     if (auto EV = dyn_cast<ExtractValueInst>(Inst))
1661       if (EV->getOperand(0) == Join)
1662         continue; // do not hoist an extract of the join
1663     // Check that the instruction's operands do not use anything in this block
1664     // (the phi nodes, or the join and extracts being left behind).
1665     for (unsigned oi = 0, oe = Inst->getNumOperands(); oi != oe; ++oi) {
1666       auto Opnd = dyn_cast<Instruction>(Inst->getOperand(oi));
1667       if (Opnd && Opnd->getParent() == BB) {
1668         LLVM_DEBUG(dbgs() << "Failed to empty branching join label for join " << Join->getName() << "\n");
1669         return; // Instruction uses something in this block: abandon.
1670       }
1671     }
1672     if (!InsertBefore) {
1673       // Lazy determination of the insert point. If it is a branching goto/join
1674       // block, insert before the goto/join.
1675       auto DomTree = getDomTree(BB->getParent());
1676       IGC_ASSERT(DomTree);
1677       auto BBNode = DomTree->getNode(BB);
1678       IGC_ASSERT(BBNode);
1679       auto InsertBB = BBNode->getIDom()->getBlock();
1680       InsertBefore = GotoJoin::isBranchingGotoJoinBlock(InsertBB);
1681       if (!InsertBefore)
1682         InsertBefore = InsertBB->getTerminator();
1683     }
1684     // Hoist the instruction.
1685     Inst->removeFromParent();
1686     Inst->insertBefore(InsertBefore);
1687     Modified = true;
1688   }
1689 }
1690 
1691 /***********************************************************************
1692  * getDomTree : get dominator tree for a function
1693  */
getDomTree(Function * F)1694 DominatorTree *GenXSimdCFConformance::getDomTree(Function *F)
1695 {
1696   if (!DTWrapper) {
1697     // In early pass, which is a module pass.
1698     if (!DTs[F]) {
1699       auto DT = new DominatorTree;
1700       DT->recalculate(*F);
1701       DTs[F] = DT;
1702     }
1703     return DTs[F];
1704   }
1705   // In late pass, use the DominatorTreeGroupWrapper.
1706   return DTWrapper->getDomTree(F);
1707 }
1708 
1709 /***********************************************************************
1710  * hoistJoin : hoist a join to the top of its basic block if possible
1711  *
1712  * Return:  whether succeeded
1713  *
1714  * This is used for a join in a block that is a join label, but not a branching
 * join block. See comment for moveCodeInJoinBlocks above.
1716  *
1717  * There might be multiple joins in the function, and the one supplied is not
1718  * necessarily the first one. If it is a later one, this function will silently
1719  * fail, which is harmless. If it silently fails for the first join, then we
1720  * end up with a join label block that does not start with a join, which
1721  * checkConformance will spot later on.
1722  *
1723  * This function does return whether it has succeeded, which is used in
1724  * moveCodeInJoinBlocks in the case that it wants to merge a loop pre-header
1725  * back into the join block.
1726  */
hoistJoin(CallInst * Join)1727 bool GenXSimdCFConformance::hoistJoin(CallInst *Join)
1728 {
1729   // This only works if no operand of the join uses one of the instructions
1730   // before it in the block, other than phi nodes.
1731   // However, if we find such an instruction and it is an extractvalue from the
1732   // result of an earlier goto/join in a different block, we can just move it
1733   // to after that goto/join.
1734   for (unsigned oi = 0, oe = Join->getNumArgOperands(); oi != oe; ++oi) {
1735     auto Opnd = dyn_cast<Instruction>(Join->getOperand(oi));
1736     if (!Opnd || isa<PHINode>(Opnd))
1737       continue;
1738     if (Opnd->getParent() == Join->getParent()) {
1739       if (auto EV = dyn_cast<ExtractValueInst>(Opnd)) {
1740         unsigned IID = GenXIntrinsic::getGenXIntrinsicID(EV->getOperand(0));
1741         if (IID == GenXIntrinsic::genx_simdcf_goto
1742             || IID == GenXIntrinsic::genx_simdcf_join) {
1743           auto GotoJoin = cast<CallInst>(EV->getOperand(0));
1744           if (GotoJoin->getParent() != Join->getParent()) {
1745             LLVM_DEBUG(dbgs() << "moving out of join block: " << *EV << "\n");
1746             EV->removeFromParent();
1747             EV->insertBefore(GotoJoin->getNextNode());
1748             continue;
1749           }
1750         }
1751       }
1752       LLVM_DEBUG(dbgs() << "hoistJoin: " << Join->getName() << " failed\n");
1753       return false; // failed -- join uses non-phi instruction before it
1754     }
1755   }
1756   // Hoist the join.
1757   auto BB = Join->getParent();
1758   auto InsertBefore = BB->getFirstNonPHIOrDbg();
1759   if (InsertBefore == Join)
1760     return true; // already at start
1761   Join->removeFromParent();
1762   Join->insertBefore(InsertBefore);
1763   // Such transformation should be performed only for Early Conformance pass
1764   if (!FG)
1765     GotoJoinEVsMap[Join].hoistEVs();
1766   Modified = true;
1767   return true;
1768 }
1769 
1770 /***********************************************************************
1771  * ensureConformance : check for conformance, and lower any non-conformant
1772  *    gotos and joins
1773  */
ensureConformance()1774 void GenXSimdCFConformance::ensureConformance()
1775 {
1776   // Push all EM values onto the stack for checking. Push the joins last, since
1777   // we want to process those before their corresponding gotos, so that
1778   // GotoJoinMap is set for a goto by the time we process a valid goto.
1779   for (auto i = EMVals.begin(), e = EMVals.end(); i != e; ++i) {
1780     auto IID = GenXIntrinsic::getGenXIntrinsicID(i->getValue());
1781     if (IID != GenXIntrinsic::genx_simdcf_join &&
1782         IID != GenXIntrinsic::genx_simdcf_unmask &&
1783         IID != GenXIntrinsic::genx_simdcf_remask)
1784       EMValsStack.insert(*i);
1785   }
1786   for (auto i = EMVals.begin(), e = EMVals.end(); i != e; ++i) {
1787     auto IID = GenXIntrinsic::getGenXIntrinsicID(i->getValue());
1788     if (IID == GenXIntrinsic::genx_simdcf_join)
1789       EMValsStack.insert(*i);
1790   } // Process the stack.
1791   SmallVector<CallInst *, 4> GotosToLower;
1792   SmallVector<CallInst *, 4> JoinsToLower;
1793   for (;;) {
1794     if (!EMValsStack.empty()) {
1795       // Remove and process the top entry on the stack.
1796       auto EMVal = EMValsStack.back();
1797       EMValsStack.pop_back();
1798       if (checkEMVal(EMVal))
1799         continue;
1800       removeBadEMVal(EMVal);
1801       if (!EMVal.getIndex()) {
1802         if (auto CI = dyn_cast<CallInst>(EMVal.getValue())) {
1803           switch (GenXIntrinsic::getGenXIntrinsicID(EMVal.getValue())) {
1804             case GenXIntrinsic::genx_simdcf_goto:
1805               GotosToLower.push_back(CI);
1806               break;
1807             case GenXIntrinsic::genx_simdcf_join:
1808               JoinsToLower.push_back(CI);
1809               break;
1810             default:
1811               break;
1812           }
1813         }
1814       }
1815       continue;
1816     }
1817     // The stack is empty. Check for EM values interfering with each other.
1818     checkEMInterference();
1819     if (EMValsStack.empty()) {
1820       // Stack still empty; we have finished.
1821       break;
1822     }
1823   }
1824 
1825   if (isLatePass()) {
1826     // In the late pass, we are not expecting to have found any non-conformant
1827     // gotos and joins that need lowering. All such gotos and joins should have
1828     // been identified in the early pass, unless passes in between have
1829     // transformed the code in an unexpected way that has made the simd CF
1830     // non-conformant. Give an error here if this has happened.
1831     IGC_ASSERT_EXIT_MESSAGE(GotosToLower.empty(),
1832       "unexpected non-conformant SIMD CF in late SIMD CF conformance pass");
1833     IGC_ASSERT_EXIT_MESSAGE(JoinsToLower.empty(),
1834       "unexpected non-conformant SIMD CF in late SIMD CF conformance pass");
1835   }
1836 
1837   // Lower gotos and joins that turned out to be non-conformant.
1838   for (auto i = GotosToLower.begin(), e = GotosToLower.end(); i != e; ++i)
1839     lowerGoto(*i);
1840   for (auto i = JoinsToLower.begin(), e = JoinsToLower.end(); i != e; ++i)
1841     lowerJoin(*i);
1842 }
1843 
1844 /***********************************************************************
 * getEMProducer : perform recursive check for EM terms.
 *
 * It goes through all phis and bitcasts (when BitCastAllowed is true)
 * and determines whether the EM is correct in data-flow terms. It doesn't
 * check live range interference, but can spot non-conformant usage
 * in case when EM from bad instruction is being used.
1851  *
1852  * This approach is used when we need to perform some actions on full
1853  * EM data flow, for example, to insert phis when eliminating redundant
1854  * bitcasts.
1855  *
1856  * All found EM producers are stored in EMProducers and can be used
1857  * later without performing full search.
1858  *
1859  * TODO: currently returns User if it deals with EM. It is done in
1860  * this way as workaround for possible future changes (for example,
1861  * getConnectedVals refactor). The idea of such approach is to be
1862  * able to update info if something changes.
1863  */
getEMProducer(Value * User,std::set<Value * > & Visited,bool BitCastAllowed)1864 Value *GenXSimdCFConformance::getEMProducer(Value *User, std::set<Value *> &Visited, bool BitCastAllowed)
1865 {
1866   LLVM_DEBUG(dbgs() << "Looking for EM producer for value:\n" << *User << "\n");
1867 
1868   if (Visited.count(User)) {
1869     if (dyn_cast<PHINode>(User))
1870       return User;
1871     return nullptr;
1872   }
1873 
1874   // Check for previously found value
1875   auto It = EMProducers.find(User);
1876   if (It != EMProducers.end()) {
1877     LLVM_DEBUG(dbgs() << "Using previously found value:\n" << *It->second << "\n");
1878     return It->second;
1879   }
1880 
1881   if (auto C = dyn_cast<Constant>(User)) {
1882     // All one is considered as EM at entry point
1883     if (C->isAllOnesValue()) {
1884       LLVM_DEBUG(dbgs() << "EMProducer is an AllOne constant\n");
1885       EMProducers[C] = C;
1886       return C;
1887     }
1888   } else if (auto PN = dyn_cast<PHINode>(User)) {
1889     // For phi node, check all its preds. They all must be EMs
1890     Visited.insert(PN);
1891     for (unsigned idx = 0, opNo = PN->getNumOperands(); idx < opNo; ++idx) {
1892       Value *Pred = PN->getOperand(idx);
1893 
1894       if (!getEMProducer(Pred, Visited, BitCastAllowed)) {
1895         LLVM_DEBUG(dbgs() << "!!! Bad phi pred detected for:\n" << *PN << "\n");
1896         EMProducers[PN] = nullptr;
1897         return nullptr;
1898       }
1899     }
1900 
1901     LLVM_DEBUG(dbgs() << "EMProducer is phi itself:\n" << *PN << "\n");
1902     EMProducers[PN] = PN;
1903     return PN;
1904   } else if (auto EVI = dyn_cast<ExtractValueInst>(User)) {
1905     // Extract value can be an EV from goto/join or from callee that
1906     // returned it. For the second case we check that the pred is
1907     // still in EM values since it could be lowered.
1908     CallInst *CI = dyn_cast<CallInst>(EVI->getOperand(0));
1909     if (CI) {
1910       // Goto/join check
1911       if (GenXIntrinsic::getGenXIntrinsicID(CI) == GenXIntrinsic::genx_simdcf_goto ||
1912         GenXIntrinsic::getGenXIntrinsicID(CI) == GenXIntrinsic::genx_simdcf_join) {
1913         LLVM_DEBUG(dbgs() << "Reached goto/join\n");
1914         EMProducers[EVI] = EVI;
1915         return EVI;
1916       }
1917 
1918       // EV from other calls.
1919       if (EMVals.count(SimpleValue(CI, EVI->getIndices()[0]))) {
1920         LLVM_DEBUG(dbgs() << "Value from return\n");
1921         EMProducers[EVI] = EVI;
1922         return EVI;
1923       }
1924     }
1925   } else if (auto Arg = dyn_cast<Argument>(User)){
1926     // For argument we need to ensure that it is still in EM values
1927     // since it could be lowered.
1928     if (EMVals.count(SimpleValue(Arg, Arg->getArgNo()))) {
1929       LLVM_DEBUG(dbgs() << "Input argument\n");
1930       EMProducers[Arg] = Arg;
1931       return Arg;
1932     }
1933   } else if (auto IVI = dyn_cast<InsertValueInst>(User)) {
1934     // Insert value prepares structure for return. Check the
1935     // value that is being inserted
1936     Visited.insert(IVI);
1937     if (auto EMProd = getEMProducer(IVI->getInsertedValueOperand(), Visited, BitCastAllowed)) {
1938       LLVM_DEBUG(dbgs() << "Insert for return\n");
1939       EMProducers[IVI] = EMProd;
1940       return IVI;
1941     }
1942   } else if (BitCastAllowed) {
1943     if (auto BCI = dyn_cast<BitCastInst>(User)) {
1944       // BitCast doesn't produce new EM. Just go through it.
1945       Visited.insert(BCI);
1946       if (auto EMProd = getEMProducer(BCI->getOperand(0), Visited, BitCastAllowed)) {
1947         LLVM_DEBUG(dbgs() << "Bitcast from EM producer\n");
1948         EMProducers[BCI] = EMProd;
1949         return BCI;
1950       }
1951     }
1952   }
1953 
1954   // All other instructions cannot be treated as EM producers
1955   LLVM_DEBUG(dbgs() << "!!! IT IS NOT A EM PRODUCER !!!\n");
1956   return nullptr;
1957 }
1958 
1959 /***********************************************************************
1960  * lowerUnsuitableGetEMs : remove all unsuitable get_em intrinsics.
1961  *
1962  * This intrinsic is unsuitable if:
1963  *   - It uses constant value: it is simply redundant
 *   - The EM argument is not actually an EM: this may happen if
1965  *     SIMD CF was non-conformant and this EM was lowered.
1966  */
lowerUnsuitableGetEMs()1967 void GenXSimdCFConformance::lowerUnsuitableGetEMs()
1968 {
1969   Type *I1Ty = Type::getInt1Ty(M->getContext());
1970   Function *GetEMDecl = GenXIntrinsic::getGenXDeclaration(
1971       M, GenXIntrinsic::genx_simdcf_get_em,
1972       {IGCLLVM::FixedVectorType::get(I1Ty, 32)});
1973   std::vector<Instruction *> ToDelete;
1974   for (auto ui = GetEMDecl->use_begin(); ui != GetEMDecl->use_end();) {
1975     std::set<Value *> Visited;
1976     auto GetEM = dyn_cast<Instruction>(ui->getUser());
1977     ++ui;
1978     auto GetEMPred = GetEM->getOperand(0);
1979 
1980     if (GetEM->use_empty()) {
1981       ToDelete.push_back(GetEM);
1982       continue;
1983     }
1984 
1985     // Constants and non-EM values should be used directly
1986     if (dyn_cast<Constant>(GetEMPred) || !getEMProducer(dyn_cast<Instruction>(GetEMPred), Visited)) {
1987       GetEM->replaceAllUsesWith(GetEM->getOperand(0));
1988       ToDelete.push_back(GetEM);
1989     }
1990   }
1991 
1992   for (auto *Inst : ToDelete) {
1993     Inst->eraseFromParent();
1994   }
1995 }
1996 
1997 /***********************************************************************
 * lowerAllSimdCF : do NOT check for conformance, and simply lower
 * all gotos, joins, unmasks and remasks
2000  */
lowerAllSimdCF()2001 void GenXSimdCFConformance::lowerAllSimdCF()
2002 {
2003   for (auto i = EMVals.begin(), e = EMVals.end(); i != e; ++i) {
2004     if (auto CI = dyn_cast<CallInst>(i->getValue())) {
2005       auto IID = GenXIntrinsic::getGenXIntrinsicID(i->getValue());
2006       if (IID == GenXIntrinsic::genx_simdcf_join)
2007         lowerJoin(CI);
2008       else if (IID == GenXIntrinsic::genx_simdcf_goto)
2009         lowerGoto(CI);
2010       else if (IID == GenXIntrinsic::genx_simdcf_unmask) {
2011         auto SaveMask = CI->getArgOperand(0);
2012         if (auto CI0 = dyn_cast<CallInst>(SaveMask)) {
2013           IRBuilder<> Builder(CI0);
2014           auto Replace = Builder.CreateBitCast(CI0->getArgOperand(0), CI0->getType());
2015           CI0->replaceAllUsesWith(Replace);
2016           CI0->eraseFromParent();
2017         }
2018         IRBuilder<> Builder(CI);
2019         auto Replace = Builder.CreateBitCast(CI->getArgOperand(1), CI->getType());
2020         CI->replaceAllUsesWith(Replace);
2021         CI->eraseFromParent();
2022       }
2023       else if (IID == GenXIntrinsic::genx_simdcf_remask) {
2024         IRBuilder<> Builder(CI);
2025         auto Replace = Builder.CreateBitCast(CI->getArgOperand(1), CI->getType());
2026         CI->replaceAllUsesWith(Replace);
2027         CI->eraseFromParent();
2028       }
2029     }
2030   }
2031 }
2032 
2033 /***********************************************************************
2034  * checkEMVal : check an EM value for conformance
2035  *
2036  * Return:    true if ok, false if the EM value needs to be removed
2037  */
checkEMVal(SimpleValue EMVal)2038 bool GenXSimdCFConformance::checkEMVal(SimpleValue EMVal)
2039 {
2040   LLVM_DEBUG(dbgs() << "checkEMVal " << *EMVal.getValue() << "#" << EMVal.getIndex() << "\n");
2041   if (!EnableGenXGotoJoin)
2042     return false; // use of goto/join disabled
2043   SmallVector<SimpleValue, 8> ConnectedVals;
2044   // Check connected values. Do not lower bad users in Late Pass because
2045   // current SIMD CF Conformance check approach expects that SIMD CF must
2046   // be OK at this point if it wasn't lowered during Early Pass.
2047   if (!getConnectedVals(EMVal, RegCategory::EM, /*IncludeOptional=*/true,
2048         /*OkJoin=*/nullptr, &ConnectedVals, /*LowerBadUsers=*/!FG)) {
2049     LLVM_DEBUG(dbgs() << "invalid def or uses\n");
2050     return false; // something invalid about the EM value itself
2051   }
2052   // Check that all connected values are EM values.
2053   for (auto i = ConnectedVals.begin(), e = ConnectedVals.end(); i != e; ++i) {
2054     SimpleValue ConnectedVal = *i;
2055     if (auto C = dyn_cast<Constant>(ConnectedVal.getValue())) {
2056       if (!C->isAllOnesValue()) {
2057         LLVM_DEBUG(dbgs() << "ConnectedVal is constant that is not all ones\n");
2058         return false; // uses constant that is not all ones, invalid
2059       }
2060     } else if (!EMVals.count(ConnectedVal)) {
2061       LLVM_DEBUG(dbgs() << "ConnectedVal is not in EMVals\n");
2062       return false; // connected value is not in EMVals
2063     }
2064   }
2065   switch (GenXIntrinsic::getGenXIntrinsicID(EMVal.getValue())) {
2066     case GenXIntrinsic::genx_simdcf_goto:
2067       return checkGoto(EMVal);
2068     case GenXIntrinsic::genx_simdcf_join:
2069       return checkJoin(EMVal);
2070     default:
2071       break;
2072   }
2073   return true;
2074 }
2075 
2076 /***********************************************************************
2077  * checkGotoJoinSunk : check whether a goto/join is sunk to the bottom of
2078  *    its basic block, other than extractvalues from its result
2079  */
checkGotoJoinSunk(CallInst * GotoJoin)2080 static bool checkGotoJoinSunk(CallInst *GotoJoin)
2081 {
2082   for (Instruction *Inst = GotoJoin;;) {
2083     Inst = Inst->getNextNode();
2084     if (Inst->isTerminator()) {
2085       if (!isa<BranchInst>(Inst))
2086         return false;
2087       break;
2088     }
2089     auto EV = dyn_cast<ExtractValueInst>(Inst);
2090     if (!EV || EV->getOperand(0) != GotoJoin)
2091       return false;
2092   }
2093   return true;
2094 }
2095 
2096 /***********************************************************************
2097  * checkGoto : check conformance of an actual goto instruction
2098  */
checkGoto(SimpleValue EMVal)2099 bool GenXSimdCFConformance::checkGoto(SimpleValue EMVal)
2100 {
2101   if (!checkGotoJoin(EMVal))
2102     return false;
2103   // Check that there is a linked join. (We do not need to check here that the
2104   // linked join is an EM value; that happened in checkEMVal due to the join
2105   // being treated as a linked value in getConnectedVals.)
2106   auto Goto = cast<CallInst>(EMVal.getValue());
2107   if (!GotoJoinMap[Goto]) {
2108     LLVM_DEBUG(dbgs() << "checkGoto: no linked join\n");
2109     return false;
2110   }
2111   // Check that the goto is sunk to the end of the block, other than extracts
2112   // from its result, and a branch. moveCodeInGotoBlocks ensures that if
2113   // possible; if that failed, this conformance check fails.
2114   if (!checkGotoJoinSunk(Goto)) {
2115     LLVM_DEBUG(dbgs() << "checkGoto: not sunk\n");
2116     return false;
2117   }
2118   return true;
2119 }
2120 
2121 /***********************************************************************
2122  * checkJoin : check conformance of an actual join instruction
2123  */
checkJoin(SimpleValue EMVal)2124 bool GenXSimdCFConformance::checkJoin(SimpleValue EMVal)
2125 {
2126   if (!checkGotoJoin(EMVal))
2127     return false;
2128   // Check that the join is at the start of the block. emptyJoinBlock should
2129   // have ensured this, unless the code was such that it could not.
2130   auto Join = cast<CallInst>(EMVal.getValue());
2131   if (!GotoJoin::isValidJoin(Join)) {
2132     LLVM_DEBUG(dbgs() << "not valid join\n");
2133     return false;
2134   }
2135   // If the !any result of this join is used in a conditional branch at the
2136   // end, check that the join is sunk to the end of the block, other than
2137   // extracts from its result, and a branch. moveCodeInJoinBlocks ensures that
2138   // if possible; if that failed, this conformance check fails.
2139   if (auto Br = dyn_cast<BranchInst>(Join->getParent()->getTerminator()))
2140     if (Br->isConditional())
2141       if (auto EV = dyn_cast<ExtractValueInst>(Br->getCondition()))
2142         if (EV->getOperand(0) == Join)
2143           if (!checkGotoJoinSunk(Join)) {
2144             LLVM_DEBUG(dbgs() << "checkJoin: not sunk\n");
2145             return false;
2146           }
2147   // Gather the web of RM values.
2148   auto RMValsEntry = &RMVals[Join];
2149   RMValsEntry->clear();
2150   LLVM_DEBUG(dbgs() << "gather web of RM vals for " << *Join << "\n");
2151   if (!isa<Constant>(Join->getOperand(1)))
2152     RMValsEntry->insert(Join->getOperand(1));
2153   for (unsigned rvi = 0; rvi != RMValsEntry->size(); ++rvi) {
2154     SimpleValue RM = (*RMValsEntry)[rvi];
2155     // RM is a value in this join's RM web. Get other values related by phi
2156     // nodes and extractvalues and gotos.
2157     SmallVector<SimpleValue, 8> ConnectedVals;
2158     bool Ok = getConnectedVals(RM, RegCategory::RM, /*IncludeOptional=*/false,
2159         Join, &ConnectedVals);
2160     LLVM_DEBUG(
2161       dbgs() << "getConnectedVals: " << RM.getValue()->getName() << "#" << RM.getIndex() << "\n";
2162       for (auto i = ConnectedVals.begin(), e = ConnectedVals.end(); i != e; ++i)
2163         dbgs() << "   " << i->getValue()->getName() << "#" << i->getIndex() << "\n"
2164     );
2165     if (!Ok) {
2166       LLVM_DEBUG(dbgs() << "illegal RM value in web\n");
2167       return false;
2168     }
2169     for (auto j = ConnectedVals.begin(), je = ConnectedVals.end();
2170         j != je; ++j) {
2171       SimpleValue ConnectedVal = *j;
2172       if (auto C = dyn_cast<Constant>(ConnectedVal.getValue())) {
2173         // A constant in the RM web must be all zeros.
2174         if (!C->isNullValue()) {
2175           LLVM_DEBUG(dbgs() << "non-0 constant in RM web\n");
2176           return false;
2177         }
2178       } else {
2179         // Insert the non-constant value.  If it is a goto with struct index
2180         // other than 1, it is illegal.
2181         if (RMValsEntry->insert(ConnectedVal)) {
2182           LLVM_DEBUG(dbgs() << "New one: " << ConnectedVal.getValue()->getName() << "#" << ConnectedVal.getIndex() << "\n");
2183           switch (GenXIntrinsic::getGenXIntrinsicID(ConnectedVal.getValue())) {
2184             case GenXIntrinsic::genx_simdcf_join:
2185               LLVM_DEBUG(dbgs() << "multiple joins in RM web\n");
2186               return false;
2187             case GenXIntrinsic::genx_simdcf_goto:
2188               if (ConnectedVal.getIndex() != 1/* struct index of RM result */) {
2189                 LLVM_DEBUG(dbgs() << "wrong struct index in goto\n");
2190                 return false;
2191               }
2192               break;
2193             default:
2194               break;
2195           }
2196         }
2197       }
2198     }
2199   }
2200   // Check whether the RM values interfere with each other.
2201   SetVector<Value *> BadDefs;
2202   checkInterference(RMValsEntry, &BadDefs, Join);
2203   if (!BadDefs.empty()) {
2204     LLVM_DEBUG(dbgs() << "RMs interfere\n");
2205     return false;
2206   }
2207   // Set GotoJoinMap for each goto in the RM web.
2208   for (unsigned rvi = 0; rvi != RMValsEntry->size(); ++rvi) {
2209     SimpleValue RM = (*RMValsEntry)[rvi];
2210     if (GenXIntrinsic::getGenXIntrinsicID(RM.getValue()) == GenXIntrinsic::genx_simdcf_goto)
2211       GotoJoinMap[cast<CallInst>(RM.getValue())] = Join;
2212   }
2213   return true;
2214 }
2215 
2216 /***********************************************************************
2217  * getEmptyCriticalEdgeSplitterSuccessor : if BB is an empty critical edge
2218  *    splitter block (one predecessor and one successor), then return the
2219  *    single successor
2220  */
getEmptyCriticalEdgeSplitterSuccessor(BasicBlock * BB)2221 static BasicBlock *getEmptyCriticalEdgeSplitterSuccessor(BasicBlock *BB)
2222 {
2223   if (!BB->hasOneUse())
2224     return nullptr; // not exactly one predecessor
2225   auto Term = dyn_cast<Instruction>(BB->getFirstNonPHIOrDbg());
2226   if (!Term->isTerminator())
2227     return nullptr; // not empty
2228   auto TI = cast<IGCLLVM::TerminatorInst>(Term);
2229   if (TI->getNumSuccessors() != 1)
2230     return nullptr; // not exactly one successor
2231   return TI->getSuccessor(0);
2232 }
2233 
2234 /***********************************************************************
2235  * checkGotoJoin : common code to check conformance of an actual goto or join
2236  *    instruction
2237  */
checkGotoJoin(SimpleValue EMVal)2238 bool GenXSimdCFConformance::checkGotoJoin(SimpleValue EMVal)
2239 {
2240   auto CI = cast<CallInst>(EMVal.getValue());
2241   // If there is an extract of the scalar result of the goto/join, check that
2242   // it is used in the conditional branch at the end of the block.
2243   ExtractValueInst *ExtractScalar = nullptr;
2244   for (auto ui = CI->use_begin(), ue = CI->use_end(); ui != ue; ++ui)
2245     if (auto EV = dyn_cast<ExtractValueInst>(ui->getUser()))
2246       if (!isa<VectorType>(EV->getType())) {
2247         if (ExtractScalar) {
2248           LLVM_DEBUG(dbgs() << "goto/join has more than one extract of its !any result\n");
2249           return false;
2250         }
2251         ExtractScalar = EV;
2252       }
2253   if (ExtractScalar) {
2254     if (!ExtractScalar->hasOneUse()) {
2255       LLVM_DEBUG(dbgs() << "goto/join's !any result does not have exactly one use\n");
2256       return false;
2257     }
2258     auto Br = dyn_cast<BranchInst>(ExtractScalar->use_begin()->getUser());
2259     if (!Br || Br->getParent() != CI->getParent()) {
2260       LLVM_DEBUG(dbgs() << "goto/join's !any result not used in conditional branch in same block\n");
2261       return false;
2262     }
2263     // For a goto/join with a conditional branch, check that the "true"
2264     // successor is a join label. We also tolerate there being an empty
2265     // critical edge splitter block in between; this will get removed in
2266     // setCategories in this pass.
2267     BasicBlock *TrueSucc = Br->getSuccessor(0);
2268     Instruction *First = TrueSucc->getFirstNonPHIOrDbg();
2269     if (GenXIntrinsic::getGenXIntrinsicID(First) != GenXIntrinsic::genx_simdcf_join) {
2270       // "True" successor is not a join label. Check for an empty critical edge
2271       // splitter block in between.
2272       TrueSucc = getEmptyCriticalEdgeSplitterSuccessor(TrueSucc);
2273       if (!TrueSucc) {
2274         LLVM_DEBUG(dbgs() << "goto/join true successor not join label\n");
2275         return false; // Not empty critical edge splitter
2276       }
2277       if (GenXIntrinsic::getGenXIntrinsicID(TrueSucc->getFirstNonPHIOrDbg())
2278           != GenXIntrinsic::genx_simdcf_join) {
2279         LLVM_DEBUG(dbgs() << "goto/join true successor not join label\n");
2280         return false; // Successor is not join label
2281       }
2282     }
2283   }
2284   return true;
2285 }
2286 
2287 /***********************************************************************
2288  * removeBadEMVal : remove a bad EM value
2289  *
2290  * This removes a non-conformant EM value, and pushes any connected EM value
2291  * onto the stack so it gets re-checked for conformance.
2292  */
removeBadEMVal(SimpleValue EMVal)2293 void GenXSimdCFConformance::removeBadEMVal(SimpleValue EMVal)
2294 {
2295   LLVM_DEBUG(
2296     dbgs() << "removeBadEMVal ";
2297     EMVal.print(dbgs());
2298     dbgs() << "\n"
2299   );
2300   // Remove the EM value.
2301   if (!EMVals.remove(EMVal))
2302     return; // was not in EMVals
2303   // Push anything related to it onto the stack for re-checking.
2304   SmallVector<SimpleValue, 8> ConnectedVals;
2305   getConnectedVals(EMVal, RegCategory::EM, /*IncludeOptional=*/true,
2306         /*OkJoin=*/nullptr, &ConnectedVals);
2307   for (auto i = ConnectedVals.begin(), e = ConnectedVals.end(); i != e; ++i) {
2308     SimpleValue ConnectedVal = *i;
2309     if (EMVals.count(ConnectedVal))
2310       EMValsStack.insert(ConnectedVal);
2311   }
2312 }
2313 
2314 /***********************************************************************
2315  * pushValues : push EM struct elements in a value onto EMValsStack
2316  */
pushValues(Value * V)2317 void GenXSimdCFConformance::pushValues(Value *V)
2318 {
2319   for (unsigned si = 0, se = IndexFlattener::getNumElements(V->getType());
2320       si != se; ++si) {
2321     SimpleValue SV(V, si);
2322     if (EMVals.count(SV))
2323       EMValsStack.insert(SV);
2324   }
2325 }
2326 
2327 /***********************************************************************
2328  * checkAllUsesAreSelectOrWrRegion : check that all uses of a value are the
2329  *    condition in select or wrregion or wrpredpredregion (or a predicate
2330  *    in a non-ALU intrinsic)
2331  *
2332  * This is used in getConnectedVals below for the result of a use of an EM
2333  * value in an rdpredregion, or a shufflevector that is a slice so will be
2334  * lowered to rdpredregion.
2335  */
checkAllUsesAreSelectOrWrRegion(Value * V)2336 static bool checkAllUsesAreSelectOrWrRegion(Value *V)
2337 {
2338   for (auto ui2 = V->use_begin(); ui2 != V->use_end(); /*empty*/) {
2339     auto User2 = cast<Instruction>(ui2->getUser());
2340     unsigned OpNum = ui2->getOperandNo();
2341     ++ui2;
2342 
2343     if (isa<SelectInst>(User2))
2344       continue;
2345 
2346     // Matches uses that can be turned into select.
2347     if (auto BI = dyn_cast<BinaryOperator>(User2)) {
2348       auto Opc = BI->getOpcode();
2349       Constant *AllOne = Constant::getAllOnesValue(V->getType());
2350       Constant *AllNul = Constant::getNullValue(V->getType());
2351 
2352       // EM && X -> sel EM X 0
2353       // EM || X -> sel EM 1 X
2354       if (Opc == BinaryOperator::And ||
2355           Opc == BinaryOperator::Or) {
2356         Value *Ops[3] = {V, nullptr, nullptr};
2357         if (Opc == BinaryOperator::And) {
2358           Ops[1] = BI->getOperand(1 - OpNum);
2359           Ops[2] = AllNul;
2360         } else if (Opc == BinaryOperator::Or) {
2361           Ops[1] = AllOne;
2362           Ops[2] = BI->getOperand(1 - OpNum);
2363         }
2364         auto SI = SelectInst::Create(Ops[0], Ops[1], Ops[2], ".revsel", BI, BI);
2365         BI->replaceAllUsesWith(SI);
2366         BI->eraseFromParent();
2367         continue;
2368       }
2369 
2370       // ~EM || X ==> sel EM, X, 1
2371       using namespace PatternMatch;
2372       if (BI->hasOneUse() &&
2373           BI->user_back()->getOpcode() == BinaryOperator::Or &&
2374           match(BI, m_Xor(m_Specific(V), m_Specific(AllOne)))) {
2375         Instruction *OrInst = BI->user_back();
2376         Value *Op = OrInst->getOperand(0) != BI ? OrInst->getOperand(0)
2377                                                 : OrInst->getOperand(1);
2378         auto SI = SelectInst::Create(V, Op, AllOne, ".revsel", OrInst, OrInst);
2379         OrInst->replaceAllUsesWith(SI);
2380         OrInst->eraseFromParent();
2381         BI->eraseFromParent();
2382         continue;
2383       }
2384 
2385       // ~EM && X ==> sel EM, 0, X
2386       using namespace PatternMatch;
2387       if (BI->hasOneUse() &&
2388           BI->user_back()->getOpcode() == BinaryOperator::And &&
2389           match(BI, m_Xor(m_Specific(V), m_Specific(AllOne)))) {
2390         Instruction *AndInst = BI->user_back();
2391         Value *Op = AndInst->getOperand(0) != BI ? AndInst->getOperand(0)
2392                                                  : AndInst->getOperand(1);
2393         auto SI = SelectInst::Create(V, AllNul, Op, ".revsel", AndInst, AndInst);
2394         AndInst->replaceAllUsesWith(SI);
2395         AndInst->eraseFromParent();
2396         BI->eraseFromParent();
2397         continue;
2398       }
2399     } else if (auto CI = dyn_cast<CastInst>(User2)) {
2400       // Turn zext/sext to select.
2401       if (CI->getOpcode() == Instruction::CastOps::ZExt ||
2402           CI->getOpcode() == Instruction::CastOps::SExt) {
2403         unsigned NElts =
2404             cast<IGCLLVM::FixedVectorType>(V->getType())->getNumElements();
2405         unsigned NBits = CI->getType()->getScalarSizeInBits();
2406         int Val = (CI->getOpcode() == Instruction::CastOps::ZExt) ? 1 : -1;
2407         APInt One(NBits, Val);
2408         Constant *LHS = ConstantVector::getSplat(
2409             IGCLLVM::getElementCount(NElts),
2410             ConstantInt::get(CI->getType()->getScalarType(), One));
2411         Constant *AllNul = Constant::getNullValue(CI->getType());
2412         auto SI = SelectInst::Create(V, LHS, AllNul, ".revsel", CI, CI);
2413         CI->replaceAllUsesWith(SI);
2414         CI->eraseFromParent();
2415         continue;
2416       }
2417     }
2418 
2419     unsigned IID = GenXIntrinsic::getAnyIntrinsicID(User2);
2420     if (GenXIntrinsic::isWrRegion(IID))
2421       continue;
2422     if (IID == GenXIntrinsic::genx_wrpredpredregion
2423         && OpNum == cast<CallInst>(User2)->getNumArgOperands() - 1)
2424       continue;
2425     if (GenXIntrinsic::isAnyNonTrivialIntrinsic(IID)
2426         && !cast<CallInst>(User2)->doesNotAccessMemory())
2427       continue;
2428     return false;
2429   }
2430   return true;
2431 }
2432 
2433 /***********************************************************************
2434  * getConnectedVals : for a SimpleValue, get other SimpleValues connected to
2435  *    it through phi nodes, insertvalue, extractvalue, goto/join, and maybe
2436  *    args and return values
2437  *
2438  * Enter:   Val = SimpleValue to start at
2439  *          Cat = RegCategory::EM to do EM connections
2440  *                RegCategory::RM to do RM connections
2441  *          IncludeOptional = for EM connections, include optional connections
2442  *                where Val is a function arg and it is connected to call args,
2443  *                and where Val is the operand to return and it is connected to
2444  *                the returned value at call sites
2445  *          OkJoin = for RM connections, error if a use in a join other than
2446  *                this one is found
2447  *          ConnectedVals = vector to store connected values in
2448  *
2449  * Return:  true if ok, false if def or some use is not suitable for EM/RM
2450  *
2451  * The provided value must be non-constant, but the returned connected values
2452  * may include constants. Duplicates may be stored in ConnectedVals.
2453  *
2454  * This function is used in three different ways by its callers:
2455  *
2456  * 1. to gather a web of putative EM values or RM values starting at goto/join
2457  *    instructions;
2458  *
2459  * 2. to test whether a putative EM/RM value is valid by whether its connected
2460  *    neighbors are EM/RM values;
2461  *
2462  * 3. when removing a value from the EM/RM values list, to find its connected
2463  *    neighbors to re-run step 2 on each of them.
2464  *
2465  * TODO: some refactoring should be performed here due to quite big
2466  *       CF with many different actions. Also some of these actions
2467  *       are repeated in different situations.
2468  */
getConnectedVals(SimpleValue Val,int Cat,bool IncludeOptional,CallInst * OkJoin,SmallVectorImpl<SimpleValue> * ConnectedVals,bool LowerBadUsers)2469 bool GenXSimdCFConformance::getConnectedVals(SimpleValue Val, int Cat,
2470     bool IncludeOptional, CallInst *OkJoin,
2471     SmallVectorImpl<SimpleValue> *ConnectedVals, bool LowerBadUsers)
2472 {
2473   // Check the def first.
2474   if (auto Arg = dyn_cast<Argument>(Val.getValue())) {
2475     if (Cat != RegCategory::EM)
2476       return false; // can't have RM argument
2477     // Connected to some return value. There is a problem here in that it might
2478     // find another predicate return value that is nothing to do with SIMD CF,
2479     // and thus stop SIMD CF being optimized. But passing a predicate in and
2480     // out of a function is rare outside of SIMD CF, so we do not worry about
2481     // that.
2482     // It is possible that EM was optimized from ret. In this case the ret type
2483     // is void. Allow such situation.
2484     Function *F = Arg->getParent();
2485     unsigned RetIdx = 0;
2486     auto RetTy = F->getReturnType();
2487     auto ValTy = IndexFlattener::getElementType(
2488         Val.getValue()->getType(), Val.getIndex());
2489     if (auto ST = dyn_cast<StructType>(RetTy)) {
2490       for (unsigned End = IndexFlattener::getNumElements(ST); ; ++RetIdx) {
2491         if (RetIdx == End)
2492           return false; // no predicate ret value found
2493         if (IndexFlattener::getElementType(ST, RetIdx) == ValTy)
2494           break;
2495       }
2496     } else if (RetTy != ValTy && !RetTy->isVoidTy())
2497       return false; // no predicate ret value found
2498     if (!RetTy->isVoidTy())
2499       for (auto fi = F->begin(), fe = F->end(); fi != fe; ++fi)
2500         if (auto Ret = dyn_cast<ReturnInst>(fi->getTerminator()))
2501           ConnectedVals->push_back(SimpleValue(Ret->getOperand(0), RetIdx));
2502     if (IncludeOptional) {
2503       // With IncludeOptional, also add the corresponding arg at each call
2504       // site.
2505       for (auto *U : F->users())
2506         if (auto *CI = checkFunctionCall(U, F))
2507           ConnectedVals->push_back(
2508               SimpleValue(CI->getArgOperand(Arg->getArgNo()), Val.getIndex()));
2509     }
2510   } else if (auto Phi = dyn_cast<PHINode>(Val.getValue())) {
2511     // phi: add (the corresponding struct element of) each incoming
2512     for (unsigned oi = 0, oe = Phi->getNumIncomingValues(); oi != oe; ++oi)
2513       ConnectedVals->push_back(
2514           SimpleValue(Phi->getIncomingValue(oi), Val.getIndex()));
2515   } else if (auto EVI = dyn_cast<ExtractValueInst>(Val.getValue())) {
2516     // extractvalue: add the appropriate struct element of the input
2517     ConnectedVals->push_back(SimpleValue(EVI->getOperand(0),
2518             Val.getIndex() + IndexFlattener::flatten(
2519               cast<StructType>(EVI->getOperand(0)->getType()),
2520               EVI->getIndices())));
2521   } else if (auto IVI = dyn_cast<InsertValueInst>(Val.getValue())) {
2522     // insertvalue: add the appropriate struct element in either the
2523     // aggregate input or the value to insert input
2524     unsigned InsertedIndex = Val.getIndex() - IndexFlattener::flatten(
2525         cast<StructType>(IVI->getType()), IVI->getIndices());
2526     unsigned NumElements = IndexFlattener::getNumElements(
2527         IVI->getOperand(1)->getType());
2528     SimpleValue SV;
2529     if (InsertedIndex < NumElements)
2530       SV = SimpleValue(IVI->getOperand(1), InsertedIndex);
2531     else
2532       SV = SimpleValue(IVI->getOperand(0), Val.getIndex());
2533     ConnectedVals->push_back(SV);
2534   } else if (auto SVI = dyn_cast<ShuffleVectorInst>(Val.getValue())) {
2535     // shufflevector: add the EM use
2536     ConnectedVals->push_back(SimpleValue(SVI->getOperand(0), 0));
2537   } else if (auto CI = dyn_cast<CallInst>(Val.getValue())) {
2538     switch (GenXIntrinsic::getAnyIntrinsicID(CI)) {
2539       case GenXIntrinsic::genx_simdcf_goto:
2540         // goto: invalid unless it is the EM/RM result of goto as applicable
2541         if (Val.getIndex() != (Cat == RegCategory::EM ? 0U : 1U))
2542           return false;
2543         // Add the corresponding input.
2544         ConnectedVals->push_back(CI->getOperand(Val.getIndex()));
2545         // If doing EM connections, add the corresponding join. This does
2546         // nothing if checkJoin has not yet run for the corresponding join,
2547         // since GotoJoinMap has not yet been set up for our goto. We tolerate
2548         // that situation; if the goto really has no linked join, that is
2549         // picked up later in checkGoto.
2550         if (Cat == RegCategory::EM)
2551           if (auto Join = GotoJoinMap[cast<CallInst>(Val.getValue())])
2552             ConnectedVals->push_back(
2553                 SimpleValue(Join, 0/* struct idx of EM result */));
2554         break;
2555       case GenXIntrinsic::genx_simdcf_join: {
2556         // join: invalid unless it is the EM result
2557         if (Val.getIndex() || Cat != RegCategory::EM)
2558           return false;
2559         // Add the corresponding input.
2560         ConnectedVals->push_back(CI->getOperand(Val.getIndex()));
2561         // Add the corresponding gotos. This does nothing if checkJoin has not
2562         // yet run for this join, since RMVals has not yet been set up for it.
2563         // That is OK, because adding the corresponding gotos here is required
2564         // only when we are called by removeBadEMVal to remove the join, so the
2565         // gotos get re-checked and found to be invalid.
2566         auto RMValsEntry = &RMVals[cast<CallInst>(Val.getValue())];
2567         for (auto i = RMValsEntry->begin(), e = RMValsEntry->end(); i != e; ++i)
2568           if (GenXIntrinsic::getGenXIntrinsicID(i->getValue()) == GenXIntrinsic::genx_simdcf_goto)
2569             ConnectedVals->push_back(
2570                 SimpleValue(i->getValue(), 0/* struct idx of EM result */));
2571         break;
2572       }
2573       case GenXIntrinsic::genx_simdcf_savemask:
2574       case GenXIntrinsic::genx_simdcf_remask:
2575       case GenXIntrinsic::genx_simdcf_get_em:
2576         // Add the corresponding input.
2577         ConnectedVals->push_back(CI->getOperand(0));
2578         return true;
2579       case GenXIntrinsic::genx_constantpred:
2580         // constantpred: add the constant. Don't add any other uses of it,
2581         // because it might be commoned up with other RM webs.
2582         ConnectedVals->push_back(CI->getOperand(0));
2583         return true;
2584       case GenXIntrinsic::not_any_intrinsic: {
2585         // Value returned from a call.
2586         if (Cat != RegCategory::EM)
2587           return false; // invalid for RM
2588         // Add the corresponding value at each return in the called function.
2589         auto CalledFunc = CI->getCalledFunction();
2590         for (auto fi = CalledFunc->begin(), fe = CalledFunc->end();
2591             fi != fe; ++fi)
2592           if (auto Ret = dyn_cast<ReturnInst>(fi->getTerminator()))
2593             if (!Ret->getType()->isVoidTy())
2594               ConnectedVals->push_back(
2595                   SimpleValue(Ret->getOperand(0), Val.getIndex()));
2596         // Connected to some call arg. There is a problem here in that it might
2597         // find another predicate arg that is nothing to do with SIMD CF, and
2598         // thus stop SIMD CF being optimized. But passing a predicate in and
2599         // out of a function is rare outside of SIMD CF, so we do not worry
2600         // about that.
2601         auto ValTy = IndexFlattener::getElementType(
2602             Val.getType(), Val.getIndex());
2603         for (unsigned Idx = 0, End = CI->getNumArgOperands(); ; ++Idx) {
2604           if (Idx == End)
2605             return false; // no corresponding call arg found
2606           if (CI->getArgOperand(Idx)->getType() == ValTy) {
2607             ConnectedVals->push_back(SimpleValue(CI->getArgOperand(Idx), 0));
2608             break;
2609           }
2610         }
2611         break;
2612       }
2613       default:
2614         return false; // unexpected call as def
2615     }
2616   } else
2617     return false; // unexpected instruction as def
2618   // Check the uses.
2619   std::vector<SimpleValue> UsersToLower;
2620   for (auto ui = Val.getValue()->use_begin(),
2621       ue = Val.getValue()->use_end(); ui != ue; ++ui) {
2622     auto User = cast<Instruction>(ui->getUser());
2623     if (auto Phi = dyn_cast<PHINode>(User)) {
2624       // Use in phi node. Add the phi result.
2625       ConnectedVals->push_back(SimpleValue(Phi, Val.getIndex()));
2626       continue;
2627     }
2628     if (auto EVI = dyn_cast<ExtractValueInst>(User)) {
2629       // Use in extractvalue.
2630       // If extracting the right index, add the result.
2631       unsigned StartIndex = IndexFlattener::flatten(
2632           cast<StructType>(EVI->getOperand(0)->getType()), EVI->getIndices());
2633       unsigned NumIndices = IndexFlattener::getNumElements(EVI->getType());
2634       unsigned ExtractedIndex = Val.getIndex() - StartIndex;
2635       if (ExtractedIndex < NumIndices)
2636         ConnectedVals->push_back(SimpleValue(EVI, ExtractedIndex));
2637       continue;
2638     }
2639     if (auto IVI = dyn_cast<InsertValueInst>(User)) {
2640       // Use in insertvalue. Could be either the aggregate input or the value
2641       // to insert.
2642       unsigned StartIndex = IndexFlattener::flatten(
2643           cast<StructType>(IVI->getType()), IVI->getIndices());
2644       unsigned NumIndices = IndexFlattener::getNumElements(
2645           IVI->getOperand(1)->getType());
2646       if (!ui->getOperandNo()) {
2647         // Use in insertvalue as the aggregate input. Add the corresponding
2648         // element in the result, as long as it is not overwritten by the
2649         // insertvalue.
2650         if (Val.getIndex() - StartIndex >= NumIndices)
2651           ConnectedVals->push_back(SimpleValue(IVI, Val.getIndex()));
2652       } else {
2653         // Use in insertvalue as the value to insert. Add the corresponding
2654         // element in the result.
2655         ConnectedVals->push_back(SimpleValue(IVI, StartIndex + Val.getIndex()));
2656       }
2657       continue;
2658     }
2659     if (isa<ReturnInst>(User)) {
2660       // Use in a return.
2661       if (Cat != RegCategory::EM)
2662         return false; // invalid for RM
2663       // Connected to some function arg. There is a problem here in that it might
2664       // find another predicate arg that is nothing to do with SIMD CF, and
2665       // thus stop SIMD CF being optimized. But passing a predicate in and
2666       // out of a function is rare outside of SIMD CF, so we do not worry
2667       // about that.
2668       auto ValTy = IndexFlattener::getElementType(
2669           Val.getType(), Val.getIndex());
2670       auto F = User->getParent()->getParent();
2671       bool Lower = false;
2672       for (auto ai = F->arg_begin(), ae = F->arg_end(); ; ++ai) {
2673         if (ai == ae) {
2674           // no arg of the right type found
2675           Lower = true;
2676           UsersToLower.push_back(SimpleValue(User, ui->getOperandNo()));
2677           break;
2678         }
2679         auto Arg = &*ai;
2680         if (Arg->getType() == ValTy) {
2681           ConnectedVals->push_back(SimpleValue(Arg, 0));
2682           break;
2683         }
2684       }
2685       if (IncludeOptional && !Lower) {
2686         // With IncludeOptional, also add the values connected by being the
2687         // return value at each call site.
2688         for (auto *U : F->users())
2689           if (auto *CI = checkFunctionCall(U, F))
2690             ConnectedVals->push_back(SimpleValue(CI, Val.getIndex()));
2691       }
2692       continue;
2693     }
2694     if (isa<SelectInst>(User)) {
2695       // A use in a select is allowed only for EM used as the condition.
2696       if (Cat != RegCategory::EM || ui->getOperandNo() != 0)
2697         UsersToLower.push_back(SimpleValue(User, ui->getOperandNo()));
2698       continue;
2699     }
2700     if (auto SVI = dyn_cast<ShuffleVectorInst>(User)) {
2701       if (!ShuffleVectorAnalyzer(SVI).isReplicatedSlice()) {
2702         UsersToLower.push_back(SimpleValue(User, ui->getOperandNo()));
2703         continue;
2704       }
2705       // This is a shufflevector that is a replicated slice, so it can be
2706       // lowered to rdpredregion or baled with instruction with channels.
2707       // (We only see this in the early pass; GenXLowering has
2708       // turned it into rdpredregion by the late pass.) Check that all its uses
2709       // are select or wrregion.
2710       if (!checkAllUsesAreSelectOrWrRegion(SVI)) {
2711         UsersToLower.push_back(SimpleValue(User, ui->getOperandNo()));
2712         continue;
2713       }
2714       // Shufflevector produces EM for value baled inst, so this is a (almost) real EM def:
2715       // add it here to perform correct EM interference check
2716       ConnectedVals->push_back(SimpleValue(SVI, ui->getOperandNo()));
2717       continue;
2718     }
2719     if (auto CI = dyn_cast<CallInst>(User)) {
2720       switch (GenXIntrinsic::getAnyIntrinsicID(CI)) {
2721         case GenXIntrinsic::genx_simdcf_get_em:
2722           IGC_ASSERT(Cat == RegCategory::EM);
2723           // Skip it if the category is right. This
2724           // intrinsic doesn't produce EM
2725           break;
2726         case GenXIntrinsic::genx_simdcf_unmask:
2727         case GenXIntrinsic::genx_simdcf_remask:
2728           IGC_ASSERT(Cat == RegCategory::EM);
2729           ConnectedVals->push_back(SimpleValue(CI, 0));
2730           break;
2731         case GenXIntrinsic::genx_simdcf_goto:
2732           // use in goto: valid only if arg 0 (EM) or 1 (RM)
2733           if (ui->getOperandNo() != (Cat == RegCategory::EM ? 0U : 1U))
2734             return false;
2735           // Add corresponding result.
2736           ConnectedVals->push_back(SimpleValue(CI, ui->getOperandNo()));
2737           break;
2738         case GenXIntrinsic::genx_simdcf_join:
2739           // use in join: valid only if arg 0 (EM) or 1 (RM)
2740           if (ui->getOperandNo() != (Cat == RegCategory::EM ? 0U : 1U))
2741             return false;
2742           // If EM, add corresponding result.
2743           if (Cat == RegCategory::EM)
2744             ConnectedVals->push_back(SimpleValue(CI, 0));
2745           else if (OkJoin && OkJoin != CI) {
2746             // RM value used in a join other than OkJoin. That is illegal, as we
2747             // can only have one join per RM web.
2748             LLVM_DEBUG(dbgs() << "getConnectedVals: found illegal join: " << CI->getName() << "\n");
2749             return false;
2750           }
2751           break;
2752         case GenXIntrinsic::genx_wrregionf:
2753         case GenXIntrinsic::genx_wrregioni:
2754           break; // Use as wrregion predicate is allowed.
2755         case GenXIntrinsic::genx_rdpredregion:
2756           // We only see rdpredregion in the late pass; in the early pass it is
2757           // still a shufflevector.  Check that all its uses are select or
2758           // wrregion.
2759           if (!checkAllUsesAreSelectOrWrRegion(CI))
2760             UsersToLower.push_back(SimpleValue(User, ui->getOperandNo()));
2761           break;
2762         case GenXIntrinsic::genx_wrpredpredregion:
2763           // Use in wrpredpredregion allowed as the last arg.
2764           if (ui->getOperandNo() + 1 != CI->getNumArgOperands())
2765             UsersToLower.push_back(SimpleValue(User, ui->getOperandNo()));
2766           break;
2767         default:
2768           // Allowed as an predicate in a non-ALU intrinsic.
2769           if (CI->getCalledFunction()->doesNotAccessMemory())
2770             UsersToLower.push_back(SimpleValue(User, ui->getOperandNo()));
2771           break;
2772         case GenXIntrinsic::not_any_intrinsic: {
2773           // Use in subroutine call. Add the corresponding function arg.
2774           Function *CalledFunc = CI->getCalledFunction();
2775           IGC_ASSERT(CalledFunc);
2776           auto ai = CalledFunc->arg_begin();
2777           for (unsigned Count = ui->getOperandNo(); Count; --Count, ++ai)
2778             ;
2779           Argument *Arg = &*ai;
2780           ConnectedVals->push_back(SimpleValue(Arg, Val.getIndex()));
2781           // Connected to some return value from the call. There is a problem
2782           // here in that it might find another predicate return value that is
2783           // nothing to do with SIMD CF, and thus stop SIMD CF being optimized.
2784           // But passing a predicate in and out of a function is rare outside
2785           // of SIMD CF, so we do not worry about that.
2786           unsigned RetIdx = 0;
2787           auto ValTy = IndexFlattener::getElementType(
2788               Val.getValue()->getType(), Val.getIndex());
2789           if (auto ST = dyn_cast<StructType>(CI->getType())) {
2790             for (unsigned End = IndexFlattener::getNumElements(ST); ; ++RetIdx) {
2791               if (RetIdx == End)
2792                 UsersToLower.push_back(SimpleValue(User, ui->getOperandNo())); // no predicate ret value found
2793               if (IndexFlattener::getElementType(ST, RetIdx) == ValTy) {
2794                 ConnectedVals->push_back(SimpleValue(CI, RetIdx));
2795                 break;
2796               }
2797             }
2798           } else if (CI->getType() == ValTy)
2799             ConnectedVals->push_back(SimpleValue(CI, 0));
2800           else if (!CI->getType()->isVoidTy())
2801             UsersToLower.push_back(SimpleValue(User, ui->getOperandNo())); // no predicate ret value found
2802           break;
2803         }
2804       }
2805       continue;
2806     }
2807     UsersToLower.push_back(SimpleValue(User, ui->getOperandNo()));
2808   }
2809 
2810   if (LowerBadUsers) {
2811     SetVector<Value *> ToRemove;
2812     for (auto BadUser : UsersToLower) {
2813       replaceUseWithLoweredEM(dyn_cast<Instruction>(BadUser.getValue()),
2814           BadUser.getIndex(), ToRemove);
2815     }
2816     for (auto Inst : ToRemove) {
2817       removeFromEMRMVals(Inst);
2818     }
2819   } else {
2820     if (!UsersToLower.empty())
2821       return false;
2822   }
2823 
2824   return true;
2825 }
2826 
2827 // check if this is an EM value or part of an EM value.
isEM(Value * V)2828 static bool isEM(Value *V) {
2829   if (auto SI = dyn_cast<ShuffleVectorInst>(V))
2830     return isEM(SI->getOperand(0)) || isEM(SI->getOperand(1));
2831   return GotoJoin::isEMValue(V);
2832 }
2833 
2834 // canonicalizeEM : canonicalize EM uses so that EM uses will not
2835 // stop SIMD-CF conformance.
canonicalizeEM()2836 void GenXSimdCFConformance::canonicalizeEM() {
2837   using namespace PatternMatch;
2838   std::vector<Instruction *> DeadInstructions;
2839 
2840   for (auto &F : M->getFunctionList())
2841     for (auto &BB : F.getBasicBlockList()) {
2842       for (Instruction *Inst = BB.getTerminator(); Inst;) {
2843         // select(C0&C1, a, b) -> select(C0, select(C1, a, b), b)
2844         // select(C0|C1, a, b) -> select(C0, a, select(C1, a, b))
2845         Value *C0, *C1, *A, *B;
2846         if (match(Inst, m_Select(m_BinOp(m_Value(C0), m_Value(C1)), m_Value(A),
2847                                  m_Value(B)))) {
2848           bool C1IsEM = isEM(C1);
2849           if (C1IsEM || isEM(C0)) {
2850             Value *Cond = Inst->getOperand(0);
2851             if (Cond->getType()->isVectorTy()) {
2852               BinaryOperator *BO = cast<BinaryOperator>(Cond);
2853               // Set Inst as insert point in order to save dominance
2854               IRBuilder<> Builder(Inst);
2855               if (C1IsEM)
2856                 std::swap(C0, C1);
2857               if (BO->getOpcode() == BinaryOperator::And) {
2858                 Value *V = Builder.CreateSelect(C1, A, B);
2859                 V = Builder.CreateSelect(C0, V, B);
2860                 Inst->replaceAllUsesWith(V);
2861                 DeadInstructions.push_back(Inst);
2862               } else if (BO->getOpcode() == BinaryOperator::Or) {
2863                 Value *V = Builder.CreateSelect(C1, A, B);
2864                 V = Builder.CreateSelect(C0, A, V);
2865                 Inst->replaceAllUsesWith(V);
2866                 DeadInstructions.push_back(Inst);
2867               }
2868             }
2869           }
2870         }
2871 
2872         Inst = (Inst == &BB.front()) ? nullptr : Inst->getPrevNode();
2873       }
2874     }
2875 
2876   for (Instruction *I : DeadInstructions)
2877     RecursivelyDeleteTriviallyDeadInstructions(I);
2878 
2879   // Collect data for gotos/joins EVs
2880   handleEVs();
2881   // Resolve bitcast chains so they don't break conformance
2882   resolveBitCastChains();
2883 }
2884 
2885 /***********************************************************************
2886  * handleEVs : collect goto/join EVs and perform some transformations
2887  * on them.
2888  *
2889  * All transformations are done in GotoJoinEVs constructor.
2890  */
handleEVs()2891 void GenXSimdCFConformance::handleEVs()
2892 {
2893   // Collect gotos/joins
2894   gatherGotoJoinEMVals(false);
2895   for (auto val : EMVals) {
2896     Value *GotoJoin = val.getValue();
2897     IGC_ASSERT(testIsGotoJoin(GotoJoin));
2898     GotoJoinEVsMap[GotoJoin] = GotoJoinEVs(GotoJoin);
2899   }
2900   EMVals.clear();
2901 }
2902 
2903 /***********************************************************************
2904  * eliminateBitCastPreds : perform bitcast elimination on EM DF
2905  *
2906  * GetEMPred should be called earlier to check if Val is actually
2907  * a EM producer.
2908  */
eliminateBitCastPreds(Value * Val,std::set<Value * > & DeadInst,std::set<Value * > & Visited)2909 Value *GenXSimdCFConformance::eliminateBitCastPreds(Value *Val, std::set<Value *> &DeadInst, std::set<Value *> &Visited)
2910 {
2911   Type *EMType =
2912       IGCLLVM::FixedVectorType::get(Type::getInt1Ty(M->getContext()), 32);
2913 
2914   if (Visited.count(Val))
2915   {
2916     return EMProducers[Val];
2917   }
2918 
2919   Visited.insert(Val);
2920 
2921   if (auto BCI = dyn_cast<BitCastInst>(Val)) {
2922     IGC_ASSERT_MESSAGE(EMProducers[BCI] == BCI->getOperand(0), "Bad EM producer was saved!");
2923 
2924     DeadInst.insert(BCI);
2925     return eliminateBitCastPreds(BCI->getOperand(0), DeadInst, Visited);
2926   } else if (auto PN = dyn_cast<PHINode>(Val)) {
2927     IGC_ASSERT_MESSAGE(EMProducers[PN] == PN, "Bad EM producer was saved!");
2928 
2929     PHINode *NewPN = nullptr;
2930     if (PN->getType() != EMType) {
2931       // Different type at phi. This may happen if its incoming value
2932       // became bitcast.
2933       LLVM_DEBUG(dbgs() << "Creating new PHI for:\n" << *PN << "\n");
2934       NewPN = PHINode::Create(EMType, PN->getNumIncomingValues(), "EMTerm", PN);
2935       EMProducers[NewPN] = NewPN;
2936       // In case of cycle, we will return newly created phi
2937       EMProducers[PN] = NewPN;
2938       // Phi can become redundant after it
2939       DeadInst.insert(PN);
2940     }
2941 
2942     for (unsigned oi = 0, on = PN->getNumIncomingValues(); oi < on; ++oi) {
2943       auto EMProd = eliminateBitCastPreds(PN->getIncomingValue(oi), DeadInst, Visited);
2944       if (!NewPN) {
2945         PN->setIncomingValue(oi, EMProd);
2946         PN->setIncomingBlock(oi, PN->getIncomingBlock(oi));
2947       } else {
2948         NewPN->addIncoming(EMProd, PN->getIncomingBlock(oi));
2949       }
2950     }
2951 
2952     return NewPN ? NewPN : PN;
2953   } else if (auto C = dyn_cast<Constant>(Val)) {
2954     IGC_ASSERT_MESSAGE(C->isAllOnesValue(), "Should be checked before!");
2955     IGC_ASSERT_MESSAGE(EMProducers[C] == C, "Bad EM producer was saved!");
2956 
2957     return Constant::getAllOnesValue(EMType);
2958   } else {
2959     IGC_ASSERT(Val);
2960     IGC_ASSERT_MESSAGE(EMProducers[Val] == Val, "Bad EM producer was saved!");
2961     IGC_ASSERT_MESSAGE(Val->getType() == EMType, "Unexpected final EM producer!");
2962 
2963     return Val;
2964   }
2965 }
2966 
2967 /***********************************************************************
2968  * resolveBitCastChains : resolve EM -> (bitcast) -> EM chains
2969  *
2970  * Standard LLVM passes create such chains sometimes and it makes
2971  * SIMD CF non-conformant. Here we check this and make changes to
2972  * resolve it if possible. If it is not, SIMD CF remains non-conformant
2973  * and is lowered later.
2974  */
resolveBitCastChains()2975 void GenXSimdCFConformance::resolveBitCastChains()
2976 {
2977   LLVM_DEBUG(dbgs() << "Resolving Bitcast chains:\n");
2978 
2979   // We don't have EM values here so we have to gather them
2980   // here, too. This is because we can change EM values set
2981   // during these transformations.
2982   gatherEMVals();
2983 
2984   std::set<Value *> DeadInst;
2985   for (auto Val : EMVals) {
2986     if (auto PN = dyn_cast<PHINode>(Val.getValue())) {
2987       LLVM_DEBUG(dbgs() << "Found phi:\n" << *PN << "\n");
2988     } else if (auto BCI = dyn_cast<BitCastInst>(Val.getValue())) {
2989       LLVM_DEBUG(dbgs() << "Found bitcast:\n" << *BCI << "\n");
2990     } else
2991       continue;
2992 
2993     std::set<Value *> Visited;
2994     Instruction *I = dyn_cast<Instruction>(Val.getValue());
2995     Value *EMProd = getEMProducer(I, Visited, true);
2996 
2997     if (!EMProd) {
2998       LLVM_DEBUG(dbgs() << "!!! Not EM producer was detected when resolving bitcast chains !!!\n");
2999       continue;
3000     }
3001 
3002     Visited.clear();
3003     Value *NewEMProd = eliminateBitCastPreds(EMProd, DeadInst, Visited);
3004     if (NewEMProd != EMProd) {
3005       EMProd->replaceAllUsesWith(NewEMProd);
3006     }
3007   }
3008 
3009   EMVals.clear();
3010 
3011   for (auto DI : DeadInst) {
3012     if (auto I = dyn_cast<Instruction>(DI))
3013       RecursivelyDeleteTriviallyDeadInstructions(I);
3014   }
3015 
3016   // TODO: since we are using EMProducers only here and during get_em check,
3017   // clean it after these transformation sinse it may contain dead data.
3018   EMProducers.clear();
3019 
3020   LLVM_DEBUG(dbgs() << "Done resolving bitcast chains:\n");
3021 }
3022 
3023 /***********************************************************************
3024  * checkEMInterference : check for EM values interfering with each other,
3025  *      lowering gotos/joins as necessary
3026  *
3027  * There is only one EM in the hardware, and we need to model that by ensuring
3028  * that our multiple EM values, including phi nodes, do not interfere with each
3029  * other. This is effectively a register allocator with only one register.
3030  */
checkEMInterference()3031 void GenXSimdCFConformance::checkEMInterference()
3032 {
3033   // Do an interference check, returning a list of defs that appear in the live
3034   // range of other values.
3035   SetVector<Value *> BadDefs;
3036   checkInterference(&EMVals, &BadDefs, nullptr);
3037   for (auto i = BadDefs.begin(), e = BadDefs.end(); i != e; ++i)
3038     removeBadEMVal(*i);
3039 }
3040 
3041 /***********************************************************************
3042  * findLoweredEMValue : find lowered EM Value
3043  */
findLoweredEMValue(Value * Val)3044 Value *GenXSimdCFConformance::findLoweredEMValue(Value *Val) {
3045   LLVM_DEBUG(dbgs() << "Looking for lowered value for:\n" << *Val << "\n");
3046 
3047   auto It = LoweredEMValsMap.find(Val);
3048   if (It != LoweredEMValsMap.end()) {
3049     auto *loweredVal = It->second;
3050     LLVM_DEBUG(dbgs() << "Found lowered value:\n" << *loweredVal << "\n");
3051     return loweredVal;
3052   }
3053 
3054   LLVM_DEBUG(dbgs() << "No lowered value was found\n");
3055 
3056   return nullptr;
3057 }
3058 
3059 /***********************************************************************
3060  * buildLoweringViaGetEM : build GetEM instruction to get explicit EM
3061  *   from Val.
3062  */
buildLoweringViaGetEM(Value * Val,Instruction * InsertBefore)3063 Value *GenXSimdCFConformance::buildLoweringViaGetEM(Value *Val,
3064                                                     Instruction *InsertBefore) {
3065   Function *GetEMDecl = GenXIntrinsic::getGenXDeclaration(
3066       M, GenXIntrinsic::genx_simdcf_get_em, {Val->getType()});
3067   Value *GetEM = CallInst::Create(GetEMDecl, {Val}, "getEM", InsertBefore);
3068   LoweredEMValsMap[Val] = GetEM;
3069 
3070   LLVM_DEBUG(dbgs() << "Built getEM:\n" << *GetEM << "\n");
3071 
3072   return GetEM;
3073 }
3074 
3075 /***********************************************************************
3076  * getGetEMLoweredValue : find lowered EM Value (via GetEM) or build
3077  *   GetEM instruction if lowered value was not found.
3078  */
getGetEMLoweredValue(Value * Val,Instruction * InsertBefore)3079 Value *GenXSimdCFConformance::getGetEMLoweredValue(Value *Val,
3080                                                    Instruction *InsertBefore) {
3081   auto *GetEM = findLoweredEMValue(Val);
3082 
3083   if (!GetEM) {
3084     GetEM = buildLoweringViaGetEM(Val, InsertBefore);
3085   }
3086 
3087   return GetEM;
3088 }
3089 
3090 /***********************************************************************
3091  * lowerEVIUse : lower ExtractValue use.
3092  *
3093  * EM is being lowered via genx_simdcf_get_em intrinsic.
3094  */
lowerEVIUse(ExtractValueInst * EVI,Instruction * User,BasicBlock * PhiPredBlock)3095 Value *GenXSimdCFConformance::lowerEVIUse(ExtractValueInst *EVI,
3096                                           Instruction *User,
3097                                           BasicBlock *PhiPredBlock) {
3098   LLVM_DEBUG(dbgs() << "Lowering EVI use:\n" << *EVI << "\n");
3099 
3100   CallInst *GotoJoin = dyn_cast<CallInst>(EVI->getOperand(0));
3101   IGC_ASSERT_MESSAGE(testIsGotoJoin(GotoJoin), "Bad ExtractValue with EM!");
3102 
3103   // The CFG was corrected for SIMD CF by earlier transformations
3104   // so isBranchingGotoJoinBlock works correctly here.
3105   if (GotoJoin::isBranchingGotoJoinBlock(GotoJoin->getParent()) == GotoJoin) {
3106     // For branching case, we need to create false and true value
3107     LLVM_DEBUG(dbgs() << "Handling branching block case\n");
3108 
3109     BasicBlock *DefBB = GotoJoin->getParent();
3110     BasicBlock *TrueBlock = DefBB->getTerminator()->getSuccessor(0);
3111     BasicBlock *FalseBlock = DefBB->getTerminator()->getSuccessor(1);
3112     BasicBlock *Loc = PhiPredBlock ? PhiPredBlock : User->getParent();
3113 
3114     // GetEM is removed later if redundant.
3115     Value *TrueVal = Constant::getNullValue(EVI->getType());
3116     Value *FalseVal = getGetEMLoweredValue(EVI, FalseBlock->getFirstNonPHI());
3117 
3118     // Early return for direct phi true edge: lowered value is zeroed
3119     if (PhiPredBlock == DefBB && TrueBlock == User->getParent()) {
3120       IGC_ASSERT(PhiPredBlock);
3121       IGC_ASSERT_MESSAGE(FalseBlock != TrueBlock,
3122                          "Crit edge should be inserted earlier!");
3123       return TrueVal;
3124     }
3125 
3126     std::map<BasicBlock *, Value *> foundVals;
3127     BasicBlockEdge TrueEdge(DefBB, TrueBlock);
3128     BasicBlockEdge FalseEdge(DefBB, FalseBlock);
3129 
3130     return findGotoJoinVal(RegCategory::EM, Loc, EVI, TrueEdge, FalseEdge,
3131                            TrueVal, FalseVal, foundVals);
3132   }
3133 
3134   // Non-branching case: must be join. Insert get_em right after join's EM
3135   IGC_ASSERT_MESSAGE(testIsJoin(GotoJoin),
3136          "Gotos should be turned into branching earlier!");
3137 
3138   LLVM_DEBUG(dbgs() << "Handling simple join case\n");
3139 
3140   return getGetEMLoweredValue(EVI, EVI->getNextNode());
3141 }
3142 
3143 /***********************************************************************
3144  * lowerPHIUse : lower PHI use.
3145  *
3146  * EM is being lowered via genx_simdcf_get_em intrinsic.
3147  * This intrinsic is inserted right after the phis in current BB
3148  * in case of non-join block. For join blocks, the full PHI lowering
3149  * is performed: we have to lower all incoming values.
3150  *
3151  * Lowered phis are also stored in LoweredPhisMap to
3152  * prevent redundant lowerings.
3153  */
lowerPHIUse(PHINode * PN,SetVector<Value * > & ToRemove)3154 Value *GenXSimdCFConformance::lowerPHIUse(PHINode *PN,
3155                                           SetVector<Value *> &ToRemove) {
3156   LLVM_DEBUG(dbgs() << "Lowering PHI use:\n" << *PN << "\n");
3157 
3158   // Check if the phi was already lowered
3159   if (auto *FoundVal = findLoweredEMValue(PN)) {
3160     return FoundVal;
3161   }
3162 
3163   if (!GotoJoin::isJoinLabel(PN->getParent())) {
3164     auto res = getGetEMLoweredValue(PN, PN->getParent()->getFirstNonPHI());
3165     LLVM_DEBUG(dbgs() << "Created " << *res << "\n");
3166     return res;
3167   }
3168 
3169   LLVM_DEBUG(dbgs() << "Performing full lowering\n");
3170 
3171   // Clone phi and store it as lowered value.
3172   auto *newPN = cast<PHINode>(PN->clone());
3173   newPN->insertAfter(PN);
3174   LoweredEMValsMap[PN] = newPN;
3175 
3176   LLVM_DEBUG(dbgs() << "Cloned phi before lowering values:\n"
3177                     << *newPN << "\n");
3178 
3179   // Lower clone's preds
3180   for (unsigned idx = 0, op_no = newPN->getNumIncomingValues(); idx < op_no;
3181        ++idx) {
3182     replaceUseWithLoweredEM(newPN, idx, ToRemove);
3183   }
3184 
3185   LLVM_DEBUG(dbgs() << "Cloned phi with lowered values:\n" << *newPN << "\n");
3186 
3187   return newPN;
3188 }
3189 
3190 /***********************************************************************
3191  * lowerArgumentUse : lower argument use.
3192  *
3193  * EM is being lowered via genx_simdcf_get_em intrinsic.
3194  * Get_em is created at function enter. Lowering can be needed
3195  * if argument's user was moved under SIMD CF due to some reason.
3196  */
lowerArgumentUse(Argument * Arg)3197 Value *GenXSimdCFConformance::lowerArgumentUse(Argument *Arg) {
3198   LLVM_DEBUG(dbgs() << "Lowering argument use:\n" << *Arg << "\n");
3199 
3200   return getGetEMLoweredValue(Arg, Arg->getParent()->front().getFirstNonPHI());
3201 }
3202 
3203 /***********************************************************************
3204  * replaceUseWithLoweredEM : lower incoming EM for user.
3205  *
3206  * EM is being lowered via genx_simdcf_get_em intrinsic.
3207  */
replaceUseWithLoweredEM(Instruction * Val,unsigned operandNo,SetVector<Value * > & ToRemove)3208 void GenXSimdCFConformance::replaceUseWithLoweredEM(Instruction *Val, unsigned operandNo, SetVector<Value *> &ToRemove)
3209 {
3210   Value *EM = Val->getOperand(operandNo);
3211 
3212   LLVM_DEBUG(dbgs() << "Replacing EM use:\n" << *EM << "\nwith lowered EM for:\n" << *Val << "\n");
3213 
3214   Value *LoweredEM = nullptr;
3215 
3216   if (auto *EVI = dyn_cast<ExtractValueInst>(EM)) {
3217     BasicBlock *PhiPredBlock = nullptr;
3218     if (auto *PN = dyn_cast<PHINode>(Val))
3219       PhiPredBlock = PN->getIncomingBlock(operandNo);
3220     LoweredEM = lowerEVIUse(EVI, Val, PhiPredBlock);
3221   } else if (auto *SVI = dyn_cast<ShuffleVectorInst>(EM)) {
3222     // Shuffle vector: go through it and lower its pred.
3223     // All changes will be applied here.
3224     replaceUseWithLoweredEM(SVI, 0, ToRemove);
3225   } else if (auto *PN = dyn_cast<PHINode>(EM)) {
3226     LoweredEM = lowerPHIUse(PN, ToRemove);
3227   } else if (auto *Arg = dyn_cast<Argument>(EM)) {
3228     LoweredEM = lowerArgumentUse(Arg);
3229   } else if (isa<Constant>(EM) && EM->getType()->getScalarType()->isIntegerTy(1)) {
3230     LoweredEM = EM;
3231   } else
3232     // All other instructions should not be EM producers with correct DF
3233     IGC_ASSERT_EXIT_MESSAGE(0, "Failed to lower EM!");
3234 
3235   if (LoweredEM)
3236     Val->setOperand(operandNo, LoweredEM);
3237 
3238   ToRemove.insert(Val);
3239 }
3240 
3241 /***********************************************************************
3242  * canUseLoweredEM : check whether instruction can use lowered EM
3243  *
3244  * Lowered EM is an explicit value that can be consumed by any
3245  * instruction except of goto and join because they take implicit EM.
3246  */
canUseLoweredEM(Instruction * Val)3247 bool GenXSimdCFConformance::canUseLoweredEM(Instruction *Val)
3248 {
3249   if (GenXIntrinsic::getGenXIntrinsicID(Val) == GenXIntrinsic::genx_simdcf_goto ||
3250     GenXIntrinsic::getGenXIntrinsicID(Val) == GenXIntrinsic::genx_simdcf_join)
3251     return false;
3252 
3253   // For phi, check that it does not deal with goto or join.
3254   if (auto PN = dyn_cast<PHINode>(Val)) {
3255     for (unsigned idx = 0, opNo = PN->getNumIncomingValues(); idx < opNo; ++idx) {
3256       auto Inst = dyn_cast<ExtractValueInst>(PN->getOperand(idx));
3257       if (Inst) {
3258         auto Pred = Inst->getOperand(0);
3259         if (GenXIntrinsic::getGenXIntrinsicID(Pred) == GenXIntrinsic::genx_simdcf_goto ||
3260           GenXIntrinsic::getGenXIntrinsicID(Pred) == GenXIntrinsic::genx_simdcf_join)
3261           return false;
3262       }
3263     }
3264   }
3265 
3266   return true;
3267 }
3268 
3269 /***********************************************************************
3270  * canUseRealEM : check whether instruction can use real EM that is
3271  * passed via #opNo operand.
3272  *
3273  * This is used to check if instruction can use real EM.
3274  *
3275  * TODO: It is used only by linearized fragment optimization now.
3276  * This function should be extended and put into getConnectedVals
3277  * algorithm in order to make the last one simplier. For now,
3278  * this check will be passed only by selects, shufflevectors and wrregions
3279  * because these instructions movement makes sence during the
3280  * optimization.
3281  */
canUseRealEM(Instruction * Inst,unsigned opNo)3282 bool GenXSimdCFConformance::canUseRealEM(Instruction *Inst, unsigned opNo) {
3283   if (auto *Select = dyn_cast<SelectInst>(Inst)) {
3284     // Real EM can be condition only
3285     return opNo == 0;
3286   }
3287 
3288   if (auto *SVI = dyn_cast<ShuffleVectorInst>(Inst)) {
3289     // TODO: getConnectedVals checks only this, but
3290     // there is no check for idxs correctness.
3291     // They should be 0, 1, 2, ..., EXEC_SIZE - 1 for
3292     // EM truncation.
3293     if (!ShuffleVectorAnalyzer::isReplicatedSlice(SVI))
3294       return false;
3295 
3296     return checkAllUsesAreSelectOrWrRegion(SVI);
3297   }
3298 
3299   // Left switch for further extensions
3300   switch (GenXIntrinsic::getGenXIntrinsicID(Inst)) {
3301   case GenXIntrinsic::genx_wrregionf:
3302   case GenXIntrinsic::genx_wrregioni:
3303     // Real EM can be wrrregion predicate only
3304     return opNo == GenXIntrinsic::GenXRegion::PredicateOperandNum;
3305   default:
3306     break;
3307   }
3308 
3309   return false;
3310 }
3311 
3312 /***********************************************************************
3313  * checkInterference : check for a list of values interfering with each other
3314  *
3315  * Enter:   Vals = values to check (not constants)
3316  *          BadDefs = SetVector in which to store any def that is found in the
3317  *                    live range of another def
3318  *          ConstStop = instruction to treat as the def point of a constantpred,
3319  *                      nullptr to treat the start of the function as the def
3320  *                      point
3321  *
3322  * This code finds interference by scanning back from uses, finding other defs,
3323  * relying on the dominance property of SSA. Having found that two EM values A
3324  * and B interfere due to the def of A appearing in the live range of B, we
3325  * could choose either one to lower its goto and join. In fact we choose A (the
3326  * found def), as that tends to lower inner SIMD CF, giving a chance for the
3327  * outer SIMD CF to become legal.
3328  *
3329  * Because GenXSimdCFConformance runs before live ranges are determined, so
3330  * that it can modify code as it wants, we cannot use the normal interference
3331  * testing code in GenXLiveness.
3332  *
3333  * The idea of ConstStop is different depending on whether we are testing
3334  * interference of all EM values, or all RM values for a particular join:
3335  *
3336  * * For interference between all EM values, any constant (input to
3337  *   constantpred intrinsic) must be all ones, which is checked elsewhere. It
3338  *   represents the state of the execution mask at the start of the function,
3339  *   therefore we need to pretend that the constantpred's live range extends
3340  *   back to the start of the function.  This is done by the caller setting
3341  *   ConstStop to 0.
3342  *
3343  * * For interference between all RM values for one particular join, any
3344  *   constant must be all zeros, which is checked elsewhere. It represents the
3345  *   state of that join's resume mask on entry to the function, and just after
3346  *   executing the join. Therefore we need to pretend that the constantpred's
3347  *   live range extends back to those two places. This is done by the caller
3348  *   setting ConstStop to the join instruction.
3349  */
checkInterference(SetVector<SimpleValue> * Vals,SetVector<Value * > * BadDefs,Instruction * ConstStop)3350 void GenXSimdCFConformance::checkInterference(SetVector<SimpleValue> *Vals,
3351     SetVector<Value *> *BadDefs, Instruction *ConstStop)
3352 {
3353   // Scan the live range of each value, looking for a def of another value.
3354   // Finding such a def indicates interference.
3355   SetVector<Value *> ToRemove;
3356   for (auto evi = Vals->begin(), eve = Vals->end(); evi != eve; ++evi) {
3357     Value *EMVal = evi->getValue();
3358     bool IsConstantPred = GenXIntrinsic::getGenXIntrinsicID(EMVal) == GenXIntrinsic::genx_constantpred;
3359     // Set of blocks where we know the value is live out.
3360     SmallSet<BasicBlock *, 8> LiveOut;
3361     // Start from each use and scan backwards. If the EMVal was affected by
3362     // transformations, there is no need to check other uses.
3363     for (auto ui = EMVal->use_begin(), ue = EMVal->use_end();
3364          ui != ue && ToRemove.count(EMVal) == 0;) {
3365       auto User = cast<Instruction>(ui->getUser());
3366       auto OpNo = ui->getOperandNo();
3367       ++ui;
3368       if (auto EVI = dyn_cast<ExtractValueInst>(User)) {
3369         // Ignore a use that is an extractvalue not involving the right struct
3370         // index.
3371         unsigned StartIndex = IndexFlattener::flatten(
3372             cast<StructType>(EVI->getOperand(0)->getType()), EVI->getIndices());
3373         unsigned NumIndices = IndexFlattener::getNumElements(EVI->getType());
3374         if (evi->getIndex() - StartIndex >= NumIndices)
3375           continue;
3376       }
3377       BasicBlock *PhiPred = nullptr;
3378       if (auto Phi = dyn_cast<PHINode>(User))
3379         PhiPred = Phi->getIncomingBlock(OpNo);
3380       auto Inst = User;
3381       SmallVector<BasicBlock *, 4> PendingBBStack;
3382       for (;;) {
3383         if (!Inst) {
3384           // Go on to the next pending predecessor.
3385           if (PendingBBStack.empty())
3386             break;
3387           Inst = PendingBBStack.back()->getTerminator();
3388           PendingBBStack.pop_back();
3389         }
3390         if (&Inst->getParent()->front() == Inst) {
3391           // Reached the start of the block. Make all unprocessed predecessors
3392           // pending. Except if the use is in a phi node and this is the first
3393           // time we reach the start of a block: in that case, mark only the
3394           // corresponding phi block is pending.
3395           if (PhiPred) {
3396             if (LiveOut.insert(PhiPred).second)
3397               PendingBBStack.push_back(PhiPred);
3398             PhiPred = nullptr;
3399           } else {
3400             BasicBlock *InstBB = Inst->getParent();
3401             std::copy_if(pred_begin(InstBB), pred_end(InstBB),
3402                          std::back_inserter(PendingBBStack),
3403                          [&LiveOut](BasicBlock *BB) {
3404                            return LiveOut.insert(BB).second;
3405                          });
3406           }
3407           Inst = nullptr;
3408           continue;
3409         }
3410         // Go back to the previous instruction. (This happens even when
3411         // starting at the end of a new block, thus skipping scanning the uses
3412         // of the terminator, but that's OK because the terminator never uses
3413         // our EM or RM values.)
3414         Inst = Inst->getPrevNode();
3415         if (Inst == EMVal && !IsConstantPred) {
3416           // Reached the def of the value. Stop scanning, unless the def is
3417           // constantpred, in which case we pretend it was live from the
3418           // ConstStop.
3419           Inst = nullptr;
3420           continue;
3421         }
3422         if (Inst == ConstStop && IsConstantPred) {
3423           // For a constantpred value, we have reached the point that we want
3424           // to treat as its definition point.  Stop scanning.
3425           Inst = nullptr;
3426           continue;
3427         }
3428         // Check if this is the def of some other EM value.
3429         if (auto VT = dyn_cast<VectorType>(Inst->getType()))
3430           if (VT->getElementType()->isIntegerTy(1))
3431             if (Vals->count(Inst) && !ToRemove.count(Inst)) {
3432               // It is the def of some other EM value. Mark that one as
3433               // interfering. However do not mark it if both values are
3434               // constantpred, since we pretend all of those are defined at the
3435               // start of the function.
3436               if (!IsConstantPred
3437                   || GenXIntrinsic::getGenXIntrinsicID(Inst) != GenXIntrinsic::genx_constantpred) {
3438                 LLVM_DEBUG(dbgs() << "GenXSimdCFConformance::checkInterference: def of " << Inst->getName() << " found in live range of " << EMVal->getName() << "\n");
3439                 auto SVI = dyn_cast<ShuffleVectorInst>(Inst);
3440                 if (SVI && SVI->getOperand(0) == EMVal) {
3441                   // Shuffle vector is baled as EM of another size: this check is to
3442                   // ensure that the EM in SVI is still actual
3443                   LLVM_DEBUG(dbgs() << "\tShuffle vector with correct arg, skipping it\n");
3444                 } else if (canUseLoweredEM(User) && !FG) {
3445                   // Lower EM in Early Pass
3446                   replaceUseWithLoweredEM(User, OpNo, ToRemove);
3447                   LLVM_DEBUG(dbgs() << "\tSucceded to lower EM for that use\n");
3448                 } else {
3449                   LLVM_DEBUG(dbgs() << "\t!!! Failed to lower EM for that use: def will be lowered\n");
3450                   BadDefs->insert(Inst);
3451                 }
3452                 // Done for that use
3453                 break;
3454               }
3455             }
3456       }
3457     }
3458   }
3459 
3460   for (auto Inst : ToRemove) {
3461     removeFromEMRMVals(Inst);
3462   }
3463 }
3464 
3465 /***********************************************************************
3466  * insertCond : insert a vector of i1 value into the start of another one
3467  *
3468  * Enter:   OldVal = value to insert into
3469  *          NewVal = value to insert, at index 0
3470  *          Name = name for any new instruction
3471  *          InsertBefore = where to insert any new instruction
3472  *          DL = debug loc to give any new instruction
3473  *
3474  * Return:  value, possibly the same as the input value
3475  */
insertCond(Value * OldVal,Value * NewVal,const Twine & Name,Instruction * InsertBefore,const DebugLoc & DL)3476 Value *GenXSimdCFConformance::insertCond(Value *OldVal, Value *NewVal,
3477     const Twine &Name, Instruction *InsertBefore, const DebugLoc &DL)
3478 {
3479   unsigned OldWidth =
3480       cast<IGCLLVM::FixedVectorType>(OldVal->getType())->getNumElements();
3481   unsigned NewWidth =
3482       cast<IGCLLVM::FixedVectorType>(NewVal->getType())->getNumElements();
3483   if (OldWidth == NewWidth)
3484     return NewVal;
3485   // Do the insert with shufflevector. We need two shufflevectors, one to extend
3486   // NewVal to OldVal's width, and one to combine them.
3487   // GenXLowering decides whether this is suitable to lower to wrpredregion, or
3488   // needs to be lowered to something less efficient.
3489   SmallVector<Constant *, 32> Indices;
3490   Type *I32Ty = Type::getInt32Ty(InsertBefore->getContext());
3491   unsigned i;
3492   for (i = 0; i != NewWidth; ++i)
3493     Indices.push_back(ConstantInt::get(I32Ty, i));
3494   auto UndefIndex = UndefValue::get(I32Ty);
3495   for (; i != OldWidth; ++i)
3496     Indices.push_back(UndefIndex);
3497   auto SV1 = new ShuffleVectorInst(NewVal, UndefValue::get(NewVal->getType()),
3498       ConstantVector::get(Indices), NewVal->getName() + ".extend", InsertBefore);
3499   SV1->setDebugLoc(DL);
3500   if (isa<UndefValue>(OldVal))
3501     return SV1;
3502   Indices.clear();
3503   for (i = 0; i != NewWidth; ++i)
3504     Indices.push_back(ConstantInt::get(I32Ty, i + OldWidth));
3505   for (; i != OldWidth; ++i)
3506     Indices.push_back(ConstantInt::get(I32Ty, i));
3507   auto SV2 = new ShuffleVectorInst(OldVal, SV1, ConstantVector::get(Indices),
3508       Name, InsertBefore);
3509   SV2->setDebugLoc(DL);
3510   return SV2;
3511 }
3512 
3513 /***********************************************************************
3514  * truncateCond : truncate a vector of i1 value
3515  *
3516  * Enter:   In = input value
3517  *          Ty = type to truncate to
3518  *          Name = name for any new instruction
3519  *          InsertBefore = where to insert any new instruction
3520  *          DL = debug loc to give any new instruction
3521  *
3522  * Return:  value, possibly the same as the input value
3523  */
truncateCond(Value * In,Type * Ty,const Twine & Name,Instruction * InsertBefore,const DebugLoc & DL)3524 Value *GenXSimdCFConformance::truncateCond(Value *In, Type *Ty,
3525     const Twine &Name, Instruction *InsertBefore, const DebugLoc &DL)
3526 {
3527   unsigned InWidth =
3528       cast<IGCLLVM::FixedVectorType>(In->getType())->getNumElements();
3529   unsigned TruncWidth = cast<IGCLLVM::FixedVectorType>(Ty)->getNumElements();
3530   if (InWidth == TruncWidth)
3531     return In;
3532   // Do the truncate with shufflevector. GenXLowering lowers it to rdpredregion.
3533   SmallVector<Constant *, 32> Indices;
3534   Type *I32Ty = Type::getInt32Ty(InsertBefore->getContext());
3535   unsigned i;
3536   for (i = 0; i != TruncWidth; ++i)
3537     Indices.push_back(ConstantInt::get(I32Ty, i));
3538   auto SV = new ShuffleVectorInst(In, UndefValue::get(In->getType()),
3539       ConstantVector::get(Indices), Name, InsertBefore);
3540   SV->setDebugLoc(DL);
3541   return SV;
3542 }
3543 
3544 /***********************************************************************
3545  * lowerGoto : lower a llvm.genx.simdcf.goto
3546  *
3547  * This also outputs a warning that we failed to optimize a SIMD branch.
3548  * We always output it, rather than including it in the -rpass mechanism
3549  * to enable or disable the warning, as it is an unexpected situation that
3550  * we want our users to report.
3551  */
lowerGoto(CallInst * Goto)3552 void GenXSimdCFConformance::lowerGoto(CallInst *Goto)
3553 {
3554   LLVM_DEBUG(dbgs() << "lowerGoto: " << *Goto << "\n");
3555   const DebugLoc &DL = Goto->getDebugLoc();
3556   if (EnableGenXGotoJoin && !lowerSimdCF)
3557     DiagnosticInfoSimdCF::emit(Goto, "failed to optimize SIMD branch", DS_Warning);
3558   Value *Results[3];
3559   auto EM = Goto->getOperand(0);
3560   auto Cond = Goto->getOperand(2);
3561   // EM is always 32 bit. Extract SubEM, of the same width as Cond, from it.
3562   auto OldSubEM = truncateCond(EM, Cond->getType(),
3563       EM->getName() + ".sub", Goto, DL);
3564   // Result 1: NewRM = OldRM | (SubEM & ~Cond)
3565   auto NotCond = BinaryOperator::Create(Instruction::Xor, Cond,
3566       Constant::getAllOnesValue(Cond->getType()),
3567       Goto->getName() + ".notcond", Goto);
3568   NotCond->setDebugLoc(DL);
3569   auto NotCondAndSubEM = BinaryOperator::Create(Instruction::And, NotCond,
3570       OldSubEM, Goto->getName() + ".disabling", Goto);
3571   NotCondAndSubEM->setDebugLoc(DL);
3572   Value *OldRM = Goto->getArgOperand(1);
3573   auto NewRM = BinaryOperator::Create(Instruction::Or, OldRM, NotCondAndSubEM,
3574       Goto->getName() + ".newRM", Goto);
3575   NewRM->setDebugLoc(DL);
3576   Results[1] = NewRM;
3577   // And SubEM with Cond.
3578   auto SubEM = BinaryOperator::Create(Instruction::And, OldSubEM, Cond,
3579       Goto->getName() + ".subEM", Goto);
3580   SubEM->setDebugLoc(DL);
3581   // Insert that back into EM. That is result 0.
3582   Results[0] = EM = insertCond(EM, SubEM, Goto->getName() + ".EM", Goto, DL);
3583   // Result 2: BranchCond = !any(SubEM)
3584   Function *AnyFunc = GenXIntrinsic::getGenXDeclaration(M, GenXIntrinsic::genx_any,
3585       SubEM->getType());
3586   auto Any = CallInst::Create(AnyFunc, SubEM,
3587       SubEM->getName() + ".any", Goto);
3588   Any->setDebugLoc(DL);
3589   auto Not = BinaryOperator::Create(Instruction::Xor, Any,
3590       Constant::getAllOnesValue(Any->getType()),
3591       Any->getName() + ".not", Goto);
3592   Not->setDebugLoc(DL);
3593   Results[2] = Not;
3594   // Replace uses.
3595   replaceGotoJoinUses(Goto, Results);
3596   Goto->eraseFromParent();
3597   Modified = true;
3598 }
3599 
3600 /***********************************************************************
3601  * lowerJoin : lower a llvm.genx.simdcf.join
3602  */
lowerJoin(CallInst * Join)3603 void GenXSimdCFConformance::lowerJoin(CallInst *Join)
3604 {
3605   LLVM_DEBUG(dbgs() << "lowerJoin: " << *Join << "\n");
3606   const DebugLoc &DL = Join->getDebugLoc();
3607   Value *Results[2];
3608   auto EM = Join->getOperand(0);
3609   auto RM = Join->getOperand(1);
3610   // EM is always 32 bit. Extract SubEM, of the same width as RM, from it.
3611   auto OldSubEM = truncateCond(EM, RM->getType(), EM->getName() + ".sub",
3612       Join, DL);
3613   // Or it with RM.
3614   auto SubEM = BinaryOperator::Create(Instruction::Or, OldSubEM, RM,
3615       Join->getName() + ".subEM", Join);
3616   SubEM->setDebugLoc(DL);
3617   // Insert that back into EM. That is result 0.
3618   Results[0] = EM = insertCond(EM, SubEM, Join->getName() + ".EM", Join, DL);
3619   // Result 1: BranchCond = !any(SubEM)
3620   Function *AnyFunc = GenXIntrinsic::getGenXDeclaration(M, GenXIntrinsic::genx_any,
3621       SubEM->getType());
3622   auto Any = CallInst::Create(AnyFunc, SubEM,
3623       SubEM->getName() + ".any", Join);
3624   Any->setDebugLoc(DL);
3625   auto Not = BinaryOperator::Create(Instruction::Xor, Any,
3626       Constant::getAllOnesValue(Any->getType()),
3627       Any->getName() + ".not", Join);
3628   Not->setDebugLoc(DL);
3629   Results[1] = Not;
3630   // Replace uses.
3631   replaceGotoJoinUses(Join, Results);
3632   Join->eraseFromParent();
3633   Modified = true;
3634 }
3635 
3636 /***********************************************************************
3637  * replaceGotoJoinUses : replace uses of goto/join
3638  *
3639  * The goto and join intrinsics have multiple return values in a struct.
3640  * This attempts to find the extractvalues and replace those directly.
3641  * It also spots where a value is unused.
3642  */
replaceGotoJoinUses(CallInst * GotoJoin,ArrayRef<Value * > Vals)3643 void GenXSimdCFConformance::replaceGotoJoinUses(CallInst *GotoJoin,
3644     ArrayRef<Value *> Vals)
3645 {
3646   SmallVector<ExtractValueInst *, 4> Extracts;
3647   for (auto ui = GotoJoin->use_begin(), ue = GotoJoin->use_end();
3648       ui != ue; ++ui) {
3649     auto Extract = dyn_cast<ExtractValueInst>(ui->getUser());
3650     if (Extract)
3651       Extracts.push_back(Extract);
3652   }
3653   for (auto ei = Extracts.begin(), ee = Extracts.end(); ei != ee; ++ei) {
3654     auto Extract = *ei;
3655     unsigned Index = Extract->getIndices()[0];
3656     if (Index >= Vals.size())
3657       continue;
3658     Extract->replaceAllUsesWith(Vals[Index]);
3659     Extract->eraseFromParent();
3660   }
3661   if (!GotoJoin->use_empty()) {
3662     // There are still some uses of the original goto/join. We need to
3663     // aggregate the result values into a struct.
3664     Value *StructVal = UndefValue::get(GotoJoin->getType());
3665     Instruction *InsertBefore = GotoJoin->getNextNode();
3666     for (unsigned Index = 0,
3667         End = cast<StructType>(GotoJoin->getType())->getNumElements();
3668         Index != End; ++Index)
3669       StructVal = InsertValueInst::Create(StructVal, Vals[Index],
3670           Index, "", InsertBefore);
3671     GotoJoin->replaceAllUsesWith(StructVal);
3672   } else {
3673     // Remove code for unused value. This is particularly useful at an outer
3674     // join, where the !any(NewEM) is unused, so we don't need to compute it.
3675     for (unsigned vi = 0; vi != Vals.size(); ++vi) {
3676       Value *V = Vals[vi];
3677       while (V && V->use_empty()) {
3678         auto I = dyn_cast<Instruction>(V);
3679         if (I == nullptr)
3680           continue;
3681         unsigned NumOperands = I->getNumOperands();
3682         if (auto CI = dyn_cast<CallInst>(I))
3683           NumOperands = CI->getNumArgOperands();
3684         V = nullptr;
3685         if (NumOperands == 1)
3686           V = I->getOperand(0);
3687         I->eraseFromParent();
3688       }
3689     }
3690   }
3691 }
3692 
3693 /***********************************************************************
3694  * fixBlockDataBeforeRemoval : clear redundant phi-s and dbg
3695  *    instruction, before erase basic block
3696  */
fixBlockDataBeforeRemoval(BasicBlock * BB,BasicBlock * SuccBB)3697 static void fixBlockDataBeforeRemoval(BasicBlock *BB, BasicBlock *SuccBB) {
3698   while (auto *PN = dyn_cast<PHINode>(BB->begin()))
3699     PN->eraseFromParent();
3700 
3701   IGC_ASSERT_MESSAGE(BB->getSingleSuccessor() == SuccBB,
3702                      "Awaiting only one successor");
3703   bool HasOnePred = SuccBB->hasNPredecessors(1);
3704   while (auto *DBG = dyn_cast<llvm::DbgVariableIntrinsic>(BB->begin())) {
3705     DBG->moveBefore(SuccBB->getFirstNonPHI());
3706     if (!HasOnePred)
3707       IGCLLVM::setDbgVariableLocationToUndef(DBG);
3708   }
3709 }
3710 
3711 /***********************************************************************
3712  * setCategories : set webs of EM and RM values to category EM or RM
3713  *
3714  * This also modifies EM uses as needed.
3715  */
setCategories()3716 void GenXLateSimdCFConformance::setCategories()
3717 {
3718   // First the EM values.
3719   for (auto ei = EMVals.begin(); ei != EMVals.end(); /* empty */) {
3720     SimpleValue EMVal = *ei;
3721     ei++;
3722     // For this EM value, set its category and modify its uses.
3723     Liveness->getOrCreateLiveRange(EMVal)->setCategory(RegCategory::EM);
3724     LLVM_DEBUG(dbgs() << "Set category for:\n" << *EMVal.getValue() << "\n");
3725     if (!isa<StructType>(EMVal.getValue()->getType()))
3726       modifyEMUses(EMVal.getValue());
3727     switch (GenXIntrinsic::getGenXIntrinsicID(EMVal.getValue())) {
3728       case GenXIntrinsic::genx_simdcf_join: {
3729         // For a join, set the category of each RM value.
3730         auto RMValsEntry = &RMVals[cast<CallInst>(EMVal.getValue())];
3731         for (auto vi = RMValsEntry->begin(), ve = RMValsEntry->end(); vi != ve; ++vi) {
3732           SimpleValue RMVal = *vi;
3733           // For this RM value, set its category.
3734           Liveness->getOrCreateLiveRange(RMVal)->setCategory(RegCategory::RM);
3735         }
3736       }
3737       // Fall through...
3738       case GenXIntrinsic::genx_simdcf_goto: {
3739         // See if this is a branching goto/join where the "true" successor is
3740         // an empty critical edge splitter block.
3741         auto CI = cast<CallInst>(EMVal.getValue());
3742         BasicBlock *BB = CI->getParent();
3743         if (GotoJoin::isBranchingGotoJoinBlock(BB) == CI) {
3744           BasicBlock *TrueSucc = BB->getTerminator()->getSuccessor(0);
3745           if (BasicBlock *TrueSuccSucc
3746               = getEmptyCriticalEdgeSplitterSuccessor(TrueSucc)) {
3747             for (PHINode &Phi : TrueSucc->phis()) {
3748               if (Phi.getNumIncomingValues() == 1) {
3749                 auto *PredInst = Phi.getIncomingValue(0);
3750                 Phi.replaceAllUsesWith(PredInst);
3751                 Liveness->eraseLiveRange(&Phi);
3752                 removeFromEMRMVals(&Phi);
3753               } else {
3754                 IGC_ASSERT_MESSAGE(true, "BB has unremovable phi");
3755               }
3756             }
3757             // Remove phi and move dbg-info
3758             fixBlockDataBeforeRemoval(TrueSucc, TrueSuccSucc);
3759             IGC_ASSERT_MESSAGE(TrueSucc->front().isTerminator(),
3760               "BB is not empty for removal");
3761             // For a branching goto/join where the "true" successor is an empty
3762             // critical edge splitter block, remove the empty block, to ensure
3763             // that the "true" successor is a join label.
3764             // Adjust phi nodes in TrueSuccSucc.
3765             adjustPhiNodesForBlockRemoval(TrueSuccSucc, TrueSucc);
3766             // Replace the use (we know there is only the one).
3767             BB->getTerminator()->setSuccessor(0, TrueSuccSucc);
3768             // Erase the critical edge splitter block.
3769             TrueSucc->eraseFromParent();
3770             Modified = true;
3771           }
3772         }
3773         break;
3774       }
3775       default:
3776         break;
3777     }
3778   }
3779 }
3780 
3781 /***********************************************************************
3782  * modifyEMUses : modify EM uses as needed
3783  */
modifyEMUses(Value * EM)3784 void GenXLateSimdCFConformance::modifyEMUses(Value *EM)
3785 {
3786   LLVM_DEBUG(dbgs() << "modifyEMUses: " << EM->getName() << "\n");
3787   // Gather the selects we need to modify, at the same time as handling other
3788   // uses of the EM values.
3789   SmallVector <SelectInst *, 4> Selects;
3790   SmallVector <Value *, 4> EMs;
3791   EMs.push_back(EM);
3792   for (unsigned ei = 0; ei != EMs.size(); ++ei) {
3793     EM = EMs[ei];
3794     // Scan EM's uses.
3795     for (auto ui = EM->use_begin(), ue = EM->use_end(); ui != ue; ++ui) {
3796       auto User = cast<Instruction>(ui->getUser());
3797       if (auto Sel = dyn_cast<SelectInst>(User)) {
3798         IGC_ASSERT(!ui->getOperandNo());
3799         Selects.push_back(Sel);
3800       } else {
3801         IGC_ASSERT(testIsValidEMUse(User, ui));
3802         if (GenXIntrinsic::getAnyIntrinsicID(User) ==
3803           GenXIntrinsic::genx_rdpredregion) {
3804           // An rdpredregion of the EM. Find its uses in select too.
3805           EMs.push_back(User);
3806         }
3807       }
3808     }
3809   }
3810   // Modify each select into a predicated wrregion.
3811   const GenXSubtarget &Subtarget = getAnalysis<TargetPassConfig>()
3812                                        .getTM<GenXTargetMachine>()
3813                                        .getGenXSubtarget();
3814   const DataLayout &DL = M->getDataLayout();
3815   for (auto si = Selects.begin(), se = Selects.end(); si != se; ++si) {
3816     auto Sel = *si;
3817     Value *FalseVal = Sel->getFalseValue();
3818 
3819     // This code removes redundancy introduced by
3820     // select & phi lying within the same goto-join region
3821     // and effectively duplicating the work.
3822     bool LoadFalseVal = true;
3823     Instruction *Goto = nullptr;
3824     if (auto *ExtrCond = dyn_cast<ExtractValueInst>(Sel->getCondition()))
3825       Goto = dyn_cast<Instruction>(ExtrCond->getAggregateOperand());
3826     for (auto *U : Sel->users()) {
3827       auto *SelPhiUser = dyn_cast<PHINode>(U);
3828       if (!SelPhiUser)
3829         continue;
3830       DominatorTree *DomTree = getDomTree(Sel->getFunction());
3831       // NOTE: we should expect exactly 2 incoming blocks,
3832       // but sometimes we may have more due to both
3833       // NoCondEV cases and critical edge splitting,
3834       // and that should not affect correctness of the transformation
3835       // IGC_ASSERT(PhiUser->getNumIncomingValues() == 2);
3836       if (auto *DomBB = DomTree->findNearestCommonDominator(
3837               SelPhiUser->getIncomingBlock(0),
3838               SelPhiUser->getIncomingBlock(1))) {
3839         auto *Term = dyn_cast<BranchInst>(DomBB->getTerminator());
3840         if (Term && Term->isConditional()) {
3841           auto *ExtrCond = dyn_cast<ExtractValueInst>(Term->getCondition());
3842           if (ExtrCond && ExtrCond->getAggregateOperand() == Goto) {
3843             LoadFalseVal = false;
3844             break;
3845           }
3846         }
3847       }
3848     }
3849 
3850     if (auto C = dyn_cast<Constant>(FalseVal)) {
3851       if (!isa<UndefValue>(C)) {
3852         if (LoadFalseVal) {
3853           // The false value needs loading if it is a constant other than
3854           // undef.
3855           SmallVector<Instruction *, 4> AddedInstructions;
3856           FalseVal =
3857               ConstantLoader(C, Subtarget, DL, nullptr, &AddedInstructions)
3858                   .loadBig(Sel);
3859           // ConstantLoader generated at least one instruction.  Ensure that
3860           // each one has debug loc and category.
3861           for (auto aii = AddedInstructions.begin(),
3862                     aie = AddedInstructions.end();
3863                aii != aie; ++aii) {
3864             Instruction *I = *aii;
3865             I->setDebugLoc(Sel->getDebugLoc());
3866           }
3867         } else
3868           // As mentioned above, we're trying to eliminate
3869           // redundancy with select+phi in a goto/join region.
3870           // So we convert select to a wrr with an undef source
3871           // for it to effectively become a simple mov
3872           FalseVal = UndefValue::get(C->getType());
3873       }
3874     }
3875     Region R(Sel);
3876     R.Mask = Sel->getCondition();
3877     IGC_ASSERT(FalseVal);
3878     Value *Wr = R.createWrRegion(FalseVal, Sel->getTrueValue(),
3879           Sel->getName(), Sel, Sel->getDebugLoc());
3880     Sel->replaceAllUsesWith(Wr);
3881     Liveness->eraseLiveRange(Sel);
3882     Sel->eraseFromParent();
3883     Modified = true;
3884   }
3885 }
3886 
3887 /***********************************************************************
3888  * optimizeRestoredSIMDCF : perform optimization on restored SIMD CF
3889  *
3890  * Restored SIMD CF is built from linear code blocks that came from
3891  * llvm transformations. Some code could be moved from SIMD CF after
3892  * join point during this transformations. This function tries to
3893  * put such code back.
3894  *
3895  * TODO: some other transformations could be applied after SIMD CF was
3896  * linearized. Maybe this function should be updated in future.
3897  */
optimizeRestoredSIMDCF()3898 void GenXSimdCFConformance::optimizeRestoredSIMDCF() {
3899   for (auto Data : BlocksToOptimize) {
3900     // Skip blocks with lowered EM values
3901     if (!EMVals.count(SimpleValue(Data.second.getRealEM(), 0))) {
3902       LLVM_DEBUG(dbgs() << "optimizeRestoredSIMDCF: skipping "
3903                         << Data.first->getName() << "\n");
3904       continue;
3905     }
3906     optimizeLinearization(Data.first, Data.second);
3907   }
3908 }
3909 
3910 /***********************************************************************
3911  * isActualStoredEM : check if Inst is a actual stored EM
3912  *
3913  * This function is called during linear fragment optimization.
3914  * The actual stored EM is a PHI node with const/getEM inputs here.
3915  * Actuallity is checked via EM-getEM map.
3916  */
isActualStoredEM(Instruction * Inst,JoinPointOptData & JPData)3917 bool GenXSimdCFConformance::isActualStoredEM(Instruction *Inst,
3918                                              JoinPointOptData &JPData) {
3919   LLVM_DEBUG(dbgs() << "isActualStoredEM: visiting\n" << *Inst << "\n");
3920   PHINode *PN = dyn_cast<PHINode>(Inst);
3921 
3922   // Linearized block should be turned into a hammock: stored EM
3923   // must come via PHI with two preds. Go through shufflevector
3924   // in case of truncated EM.
3925   if (auto *SVI = dyn_cast<ShuffleVectorInst>(Inst)) {
3926     LLVM_DEBUG(dbgs() << "Truncated EM detected\n");
3927 
3928     // Check SVI trunc correctness
3929     if (!canUseRealEM(Inst, 0)) {
3930       LLVM_DEBUG(dbgs() << "Bad trunc via SVI: not an actual EM\n");
3931       return false;
3932     }
3933 
3934     PN = dyn_cast<PHINode>(SVI->getOperand(0));
3935   }
3936   if (!PN || PN->getNumIncomingValues() != 2) {
3937     LLVM_DEBUG(dbgs() << "Incompatable inst: not an actual EM\n");
3938     return false;
3939   }
3940 
3941   Value *ExpectedGetEM = PN->getIncomingValueForBlock(JPData.getFalsePred());
3942   Value *ExpectedConstEM = PN->getIncomingValueForBlock(JPData.getTruePred());
3943 
3944   IGC_ASSERT_MESSAGE(ExpectedGetEM, "Bad phi in hammock!");
3945   IGC_ASSERT_MESSAGE(ExpectedConstEM, "Bad phi in hammock!");
3946 
3947   // Find stored value
3948   auto It = LoweredEMValsMap.find(JPData.getRealEM());
3949   if (It == LoweredEMValsMap.end()) {
3950     LLVM_DEBUG(dbgs() << "No EM was stored: not an actual EM\n");
3951     return false;
3952   }
3953 
3954   // Check if the val from SIMD BB is a stored via get.em EM
3955   if (ExpectedGetEM != It->second) {
3956     LLVM_DEBUG(
3957         dbgs() << "SIMD BB value is not a correct get.em: not an actual EM\n");
3958     return false;
3959   }
3960 
3961   // Check if the val from True BB is an all null constant
3962   if (ExpectedConstEM != Constant::getNullValue(ExpectedConstEM->getType())) {
3963     LLVM_DEBUG(
3964         dbgs()
3965         << "True BB value is not a correct constant: not an actual EM\n");
3966     return false;
3967   }
3968 
3969   LLVM_DEBUG(dbgs() << "All checks passed\n");
3970   return true;
3971 }
3972 
3973 /***********************************************************************
3974  * canBeMovedUnderSIMDCF : check if Instruction can be moved under
3975  * SIMD CF
3976  *
3977  * This function is called during linear fragment optimization.
3978  * We can move instruction if such movement does not corrupt
3979  * dominance. Sometimes we can meet several instruction that
3980  * should be moved. There is a recursive call, all instructions
3981  * in chain are stored in Visited set.
3982  */
canBeMovedUnderSIMDCF(Value * Val,BasicBlock * CurrBB,JoinPointOptData & JPData,std::set<Instruction * > & Visited)3983 bool GenXSimdCFConformance::canBeMovedUnderSIMDCF(
3984     Value *Val, BasicBlock *CurrBB, JoinPointOptData &JPData,
3985     std::set<Instruction *> &Visited) {
3986   Instruction *Inst = dyn_cast<Instruction>(Val);
3987 
3988   // Can be non-inst. In this case we have nothing to check.
3989   if (!Inst)
3990     return true;
3991 
3992   LLVM_DEBUG(dbgs() << "canBeMovedUnderSIMDCF: visiting\n" << *Inst << "\n");
3993 
3994   // Mark instruction as visited. Return if it was already added to set:
3995   // we don't expect it to be here.
3996   if (!Visited.insert(Inst).second) {
3997     LLVM_DEBUG(dbgs() << "Instruction was already visited: do not move\n");
3998     return false;
3999   }
4000 
4001   // Instruction is not located in linearized fragment
4002   if (Inst->getParent() != CurrBB) {
4003     LLVM_DEBUG(dbgs() << "Out of linearized fragment: do not move\n");
4004     return false;
4005   }
4006 
4007   // Do not move join instruction
4008   if (GenXIntrinsic::getGenXIntrinsicID(Inst) ==
4009       GenXIntrinsic::genx_simdcf_join) {
4010     LLVM_DEBUG(dbgs() << "Join instruction: do not move\n");
4011     return false;
4012   }
4013 
4014   // TODO: current assumption is that nothing except linearization was applied
4015   // Skip instruction that has more than one user
4016   if (!Inst->hasOneUse()) {
4017     LLVM_DEBUG(dbgs() << "More than one user: do not move\n");
4018     return false;
4019   }
4020 
4021   // Check operands
4022   for (unsigned i = 0, e = Inst->getNumOperands(); i < e; ++i) {
4023     Instruction *Pred = dyn_cast<Instruction>(Inst->getOperand(i));
4024 
4025     // Not an instruction: not blocking moving
4026     if (!Pred)
4027       continue;
4028 
4029     // Check for dominance
4030     DominatorTree *DomTree = getDomTree(CurrBB->getParent());
4031     if (DomTree->dominates(Pred, JPData.getFalsePred()))
4032       continue;
4033 
4034     // Check for actual saved EM: it is a phi located in current BB,
4035     // so the dominance check failed
4036     if (isActualStoredEM(Pred, JPData))
4037       continue;
4038 
4039     // Dominance check and EM check failed: instruction is inside this block
4040     LLVM_DEBUG(dbgs() << "Recursive call for operand #" << i << "\n");
4041     if (canBeMovedUnderSIMDCF(Pred, CurrBB, JPData, Visited))
4042       continue;
4043 
4044     // Recursive call failed: do not move
4045     LLVM_DEBUG(dbgs() << "canBeMovedUnderSIMDCF: bad operand: do not move\n");
4046     return false;
4047   }
4048 
4049   LLVM_DEBUG(dbgs() << "canBeMovedUnderSIMDCF: move\n" << *Inst << "\n");
4050   return true;
4051 }
4052 
4053 /***********************************************************************
4054  * isSelectConditionCondEV : check if Select's condition is a stored
4055  * Cond EV value.
4056  *
4057  * This function is called during linear fragment optimization.
4058  * Linear fragment optimization bases on the fact that LLVM performed
4059  * code movement with PHI -> select transformation. This function
4060  * checks if the select condition is a CondEV from previous SIMD
4061  * branching instruction.
4062  *
4063  * This function also can handle constant vectorization if it
4064  * was applied: it doesn't break SIMD CF CondEV semantics.
4065  */
isSelectConditionCondEV(SelectInst * Sel,JoinPointOptData & JPData)4066 bool GenXSimdCFConformance::isSelectConditionCondEV(SelectInst *Sel,
4067                                                     JoinPointOptData &JPData) {
4068   PHINode *PN = dyn_cast<PHINode>(Sel->getCondition());
4069   if (!PN)
4070     return false;
4071 
4072   // CondEV Phi must be in the same BB
4073   if (PN->getParent() != Sel->getParent())
4074     return false;
4075 
4076   Value *TrueBlockValue = PN->getIncomingValueForBlock(JPData.getTruePred());
4077   Value *FalseBlockValue = PN->getIncomingValueForBlock(JPData.getFalsePred());
4078 
4079   IGC_ASSERT_MESSAGE(TrueBlockValue, "Bad phi in hammock!");
4080   IGC_ASSERT_MESSAGE(FalseBlockValue, "Bad phi in hammock!");
4081 
4082   Constant *TrueBlockConst = dyn_cast<Constant>(TrueBlockValue);
4083   Constant *FalseBlockConst = dyn_cast<Constant>(FalseBlockValue);
4084 
4085   if (!TrueBlockConst || !FalseBlockConst)
4086     return false;
4087 
4088   // It is not necessary to check constant type due CondEV semantics
4089   if (!TrueBlockConst->isOneValue() || !FalseBlockConst->isNullValue())
4090     return false;
4091 
4092   return true;
4093 }
4094 
4095 /***********************************************************************
4096  * replaceGetEMUse : find and replace GetEM uses to fix dominance.
4097  *
4098  * This function is called during linear fragment optimization.
4099  * After we moved Inst to SIMD BB, we need to update EM connections
4100  * according to updated DF.
4101  *
4102  * In many cases we can place real EM instead of lowered one.
4103  * GetEM may become redundant - it will be removed later in this pass.
4104  *
4105  * Note: SIMD CF will be non-conformant if Inst is left at the point
4106  * where new EM was generated. Inst is moved after that replacement
4107  * in linearized fragment optimization so conformance is not broken.
4108  */
replaceGetEMUse(Instruction * Inst,JoinPointOptData & JPData)4109 void GenXSimdCFConformance::replaceGetEMUse(Instruction *Inst,
4110                                             JoinPointOptData &JPData) {
4111   for (unsigned i = 0, e = Inst->getNumOperands(); i < e; ++i) {
4112     Instruction *Pred = dyn_cast<Instruction>(Inst->getOperand(i));
4113 
4114     if (!Pred)
4115       continue;
4116 
4117     // EM must be in the same BB
4118     if (Pred->getParent() != Inst->getParent())
4119       continue;
4120 
4121     if (!isActualStoredEM(Pred, JPData))
4122       continue;
4123 
4124     if (canUseRealEM(Inst, i)) {
4125       // Replace with real EM
4126       Instruction *NewOp = JPData.getRealEM();
4127       Instruction *FullEM = nullptr;
4128       if (isa<ShuffleVectorInst>(Pred)) {
4129         // Copy truncation via SVI
4130         NewOp = Pred->clone();
4131         NewOp->insertBefore(JPData.getFalsePred()->getTerminator());
4132         NewOp->setOperand(0, JPData.getRealEM());
4133         FullEM = cast<Instruction>(Pred->getOperand(0));
4134       }
4135       Inst->setOperand(i, NewOp);
4136 
4137       // Remove Pred if it is not needed anymore.
4138       // Do the same for FullEM.
4139       // GetEM that was used here will be handled later.
4140       if (Pred->use_empty()) {
4141         Pred->eraseFromParent();
4142       }
4143       if (FullEM && FullEM->use_empty()) {
4144         FullEM->eraseFromParent();
4145       }
4146     } else {
4147       // Replace with lowered EM
4148       auto it = LoweredEMValsMap.find(JPData.getRealEM());
4149       IGC_ASSERT(it != LoweredEMValsMap.end() && "Should be checked earlier");
4150       Instruction *LoweredEM = cast<Instruction>(it->second);
4151       Inst->setOperand(i, LoweredEM);
4152 
4153       if (Pred->use_empty())
4154         Pred->eraseFromParent();
4155     }
4156   }
4157 }
4158 
4159 /***********************************************************************
4160  * optimizeLinearization : optimize linearized fragment
4161  *
4162  * This optimization restores SIMD CF for linearized fragment.
4163  * To detect code that can be moved under SIMD CF, we need to find
4164  * a the following select inst:
4165  *    Val = select CondEV, OldVal, NewVal
4166  * Details can be found below.
4167  */
optimizeLinearization(BasicBlock * BB,JoinPointOptData & JPData)4168 void GenXSimdCFConformance::optimizeLinearization(BasicBlock *BB,
4169                                                   JoinPointOptData &JPData) {
4170   std::set<Instruction *> InstsToMove;
4171   std::vector<SelectInst *> SelectsToOptimize;
4172   for (Instruction *Inst = BB->getTerminator()->getPrevNode();
4173        Inst && !dyn_cast<PHINode>(Inst); Inst = Inst->getPrevNode()) {
4174     // We are looking for "Val = select CondEV, OldVal, NewVal" instruction.
4175     //
4176     // Linearization put NewVal calculations after JP. OldVal came from True BB
4177     // via PHI instruction. We can move NewVal calculations under SIMD CF and
4178     // place a PHINode instead of this select. Val, OldVal and NewVal will be
4179     // coalesced and allocated on the same register later.
4180     SelectInst *Select = dyn_cast<SelectInst>(Inst);
4181     if (!Select || !isSelectConditionCondEV(Select, JPData))
4182       continue;
4183 
4184     // Check if OldVal came from outside. Also it can be a constant.
4185     // TODO: current assumption is that nothing except linearization was
4186     // applied. It is possible that OldVal was moved down after it. We also can
4187     // move it back but some analysis is required to avoid possible overhead.
4188     // Not done now.
4189     Value *OldVal = Select->getTrueValue();
4190     if (Instruction *OldValInst = dyn_cast<Instruction>(OldVal)) {
4191       DominatorTree *DomTree = getDomTree(BB->getParent());
4192       // Must dominate this BB and SIMD BB
4193       if (!DomTree->dominates(OldValInst, BB) ||
4194           !DomTree->dominates(OldValInst, JPData.getFalsePred()))
4195         continue;
4196     }
4197 
4198     // Check NewVal
4199     Value *NewVal = Select->getFalseValue();
4200     std::set<Instruction *> Visited;
4201     if (!canBeMovedUnderSIMDCF(NewVal, BB, JPData, Visited))
4202       continue;
4203 
4204     // We can optimize this select
4205     InstsToMove.insert(Visited.begin(), Visited.end());
4206     SelectsToOptimize.push_back(Select);
4207   }
4208 
4209   // Move instructions
4210   // FIXME: there must be a way to do it in a better manner
4211   // The idea of this is to save the instructions' order so we don't brake
4212   // dominance when movement is performed.
4213   std::vector<Instruction *> OrderedInstsToMove;
4214   for (Instruction *Inst = BB->getFirstNonPHI(); Inst;
4215        Inst = Inst->getNextNode()) {
4216     if (InstsToMove.find(Inst) == InstsToMove.end())
4217       continue;
4218     OrderedInstsToMove.push_back(Inst);
4219   }
4220   for (auto *Inst : OrderedInstsToMove) {
4221     replaceGetEMUse(Inst, JPData);
4222     Inst->moveBefore(JPData.getFalsePred()->getTerminator());
4223   }
4224 
4225   // Handle selects
4226   for (auto *Select : SelectsToOptimize) {
4227     PHINode *PN = PHINode::Create(Select->getType(), 2, "optimized_sel",
4228                                   BB->getFirstNonPHI());
4229     PN->addIncoming(Select->getTrueValue(), JPData.getTruePred());
4230     PN->addIncoming(Select->getFalseValue(), JPData.getFalsePred());
4231     Select->replaceAllUsesWith(PN);
4232     Select->eraseFromParent();
4233   }
4234 }
4235 
4236 /***********************************************************************
4237  * GotoJoinEVs::GotoJoinEVs : collects and handle EVs. See CollectEVs
4238  * for more info.
4239  */
GotoJoinEVs(Value * GJ)4240 GenXSimdCFConformance::GotoJoinEVs::GotoJoinEVs(Value* GJ) {
4241   GotoJoin = GJ;
4242 
4243   if (!GotoJoin)
4244     return;
4245 
4246   switch (GenXIntrinsic::getGenXIntrinsicID(GotoJoin)) {
4247   case GenXIntrinsic::genx_simdcf_goto:
4248     IsGoto = true;
4249     break;
4250   case GenXIntrinsic::genx_simdcf_join:
4251     IsGoto = false;
4252     break;
4253   default:
4254     IGC_ASSERT_MESSAGE(0, "Expected goto or join!");
4255     break;
4256   }
4257 
4258   CollectEVs();
4259 }
4260 
4261 /***********************************************************************
4262  * GotoJoinEVs::getEMEV : get EV for goto/join Execution Mask
4263  */
getEMEV() const4264 ExtractValueInst *GenXSimdCFConformance::GotoJoinEVs::getEMEV() const {
4265   IGC_ASSERT_MESSAGE(GotoJoin, "Uninitialized GotoJoinEVs Data!");
4266   static_assert(EMPos < (sizeof(EVs) / sizeof(*EVs)));
4267   return EVs[EMPos];
4268 }
4269 
4270 /***********************************************************************
4271  * GotoJoinEVs::getRMEV : get EV for goto/join Resume Mask
4272  */
getRMEV() const4273 ExtractValueInst *GenXSimdCFConformance::GotoJoinEVs::getRMEV() const {
4274   IGC_ASSERT_MESSAGE(GotoJoin, "Uninitialized GotoJoinEVs Data!");
4275   IGC_ASSERT_MESSAGE(IsGoto, "Only goto returns RM!");
4276   static_assert(RMPos < (sizeof(EVs) / sizeof(*EVs)));
4277   return EVs[RMPos];
4278 }
4279 
4280 /***********************************************************************
4281  * GotoJoinEVs::getCondEV : get EV for goto/join condition
4282  */
getCondEV() const4283 ExtractValueInst *GenXSimdCFConformance::GotoJoinEVs::getCondEV() const {
4284   ExtractValueInst *Result = nullptr;
4285   IGC_ASSERT_MESSAGE(GotoJoin, "Uninitialized GotoJoinEVs Data!");
4286   if (IsGoto) {
4287     static_assert(GotoCondPos < (sizeof(EVs) / sizeof(*EVs)));
4288     Result = EVs[GotoCondPos];
4289   } else {
4290     static_assert(JoinCondPos < (sizeof(EVs) / sizeof(*EVs)));
4291     Result = EVs[JoinCondPos];
4292   }
4293   return Result;
4294 }
4295 
getGotoJoin() const4296 Value *GenXSimdCFConformance::GotoJoinEVs::getGotoJoin() const {
4297   IGC_ASSERT_MESSAGE(GotoJoin, "Uninitialized GotoJoinEVs Data!");
4298   return GotoJoin;
4299 }
4300 
4301 /***********************************************************************
4302  * GotoJoinEVs::getSplitPoint : find first instruction that is not
4303  * a EV or doesn't use Goto/Join. Such instruction always exists
4304  * in a correct IR - BB terminator is a such instruction.
4305  */
getSplitPoint() const4306  Instruction *GenXSimdCFConformance::GotoJoinEVs::getSplitPoint() const {
4307   IGC_ASSERT_MESSAGE(GotoJoin, "Uninitialized GotoJoinEVs Data!");
4308   Instruction *SplitPoint = cast<Instruction>(GotoJoin)->getNextNode();
4309   for (; isa<ExtractValueInst>(SplitPoint) && SplitPoint->getOperand(0) == GotoJoin;
4310     SplitPoint = SplitPoint->getNextNode());
4311   return SplitPoint;
4312  }
4313 
4314 /***********************************************************************
4315  * GotoJoinEVs::setCondEV : set EV for goto/join condition. It is
4316  * needed on basic block splitting to handle bad Cond EV user.
4317  */
setCondEV(ExtractValueInst * CondEV)4318 void GenXSimdCFConformance::GotoJoinEVs::setCondEV(ExtractValueInst *CondEV) {
4319   IGC_ASSERT_MESSAGE(GotoJoin, "Uninitialized GotoJoinEVs Data!");
4320   IGC_ASSERT_MESSAGE(!getCondEV(), "CondEV is already set!");
4321   if (IsGoto) {
4322     static_assert(GotoCondPos < (sizeof(EVs) / sizeof(*EVs)));
4323     EVs[GotoCondPos] = CondEV;
4324   } else {
4325     static_assert(JoinCondPos < (sizeof(EVs) / sizeof(*EVs)));
4326     EVs[JoinCondPos] = CondEV;
4327   }
4328 }
4329 
4330 /***********************************************************************
4331  * GotoJoinEVs::isGoto : check wether this EVs info belongs to goto
4332  */
isGoto() const4333 bool GenXSimdCFConformance::GotoJoinEVs::isGoto() const {
4334   IGC_ASSERT_MESSAGE(GotoJoin, "Uninitialized GotoJoinEVs Data!");
4335   return IsGoto;
4336 }
4337 
4338 /***********************************************************************
4339  * GotoJoinEVs::isJoin : check wether this EVs info belongs to join
4340  */
isJoin() const4341 bool GenXSimdCFConformance::GotoJoinEVs::isJoin() const {
4342   IGC_ASSERT_MESSAGE(GotoJoin, "Uninitialized GotoJoinEVs Data!");
4343   return !IsGoto;
4344 }
4345 
4346 /***********************************************************************
4347  * GotoJoindEVs::CollectEVs : handle and store goto/join EVs
4348  *
4349  * This does the following steps:
4350  *  - Locate EVs. If we found a duplicate, just replace users.
4351  *  - Move EVs right after the goto/join
4352  *  - Add missing EM and RM. This is needed for correct liverange
4353  *    interference analysis.
4354  */
CollectEVs()4355 void GenXSimdCFConformance::GotoJoinEVs::CollectEVs() {
4356   IGC_ASSERT_MESSAGE(GotoJoin, "Uninitialized GotoJoinEVs Data!");
4357   IGC_ASSERT_MESSAGE(testIsGotoJoin(GotoJoin), "Expected goto or join!");
4358 
4359   auto GotoJoinInst = dyn_cast<Instruction>(GotoJoin);
4360 
4361   // Collect EVs, hoist them, resolve duplications
4362   for (auto ui = GotoJoin->use_begin(), ue = GotoJoin->use_end(); ui != ue;) {
4363 
4364     auto EV = dyn_cast<ExtractValueInst>(ui->getUser());
4365     ++ui;
4366 
4367     IGC_ASSERT_MESSAGE(EV, "Bad user of goto/join!");
4368     IGC_ASSERT_MESSAGE(EV->getNumIndices() == 1, "Expected 1 index in Extract Value for goto/join!");
4369 
4370     const unsigned idx = EV->getIndices()[0];
4371     IGC_ASSERT(testPosCorrectness(idx));
4372 
4373     LLVM_DEBUG(dbgs() << "Found EV:\n" << *EV << "\n");
4374     IGC_ASSERT(idx < (sizeof(EVs) / sizeof(*EVs)));
4375     if (EVs[idx]) {
4376       LLVM_DEBUG(dbgs() << "Duplication: replacing users with:\n" << *EVs[idx] << "\n");
4377       EV->replaceAllUsesWith(EVs[idx]);
4378       EV->eraseFromParent();
4379     }
4380     else {
4381       LLVM_DEBUG(dbgs() << "Saving it.\n");
4382       EVs[idx] = EV;
4383     }
4384   }
4385 
4386   // Add missing EVs for masks
4387   for (unsigned idx = 0, end = IsGoto ? RMPos : EMPos; idx <= end; ++idx) {
4388     IGC_ASSERT(idx < (sizeof(EVs) / sizeof(*EVs)));
4389     if (EVs[idx])
4390       continue;
4391 
4392     std::string Name = "missing";
4393     switch (idx) {
4394     case EMPos:
4395       Name += "EMEV";
4396       break;
4397     case RMPos:
4398       Name += "RMEV";
4399       break;
4400     case GotoCondPos:
4401       Name += "CondEV";
4402       break;
4403     }
4404 
4405     auto EV = ExtractValueInst::Create(GotoJoin, { idx }, Name, GotoJoinInst->getParent());
4406     EVs[idx] = EV;
4407   }
4408 
4409   hoistEVs();
4410 }
4411 
4412 /***********************************************************************
4413  * GotoJoinEVs::hoistEVs : move EVs right after goto/join
4414  */
hoistEVs() const4415 void GenXSimdCFConformance::GotoJoinEVs::hoistEVs() const{
4416   IGC_ASSERT_MESSAGE(GotoJoin, "Uninitialized GotoJoinEVs Data!");
4417 
4418   LLVM_DEBUG(dbgs() << "Moving EV users after:\n" << *GotoJoin << "\n");
4419 
4420   const size_t count = (sizeof(EVs) / sizeof(*EVs));
4421   for (size_t idx = 0; idx < count; ++idx) {
4422     if (EVs[idx])
4423       EVs[idx]->moveAfter(dyn_cast<Instruction>(GotoJoin));
4424   }
4425 }
4426 
4427 /***********************************************************************
4428  * DiagnosticInfoSimdCF::emit : emit an error or warning
4429  */
emit(Instruction * Inst,StringRef Msg,DiagnosticSeverity Severity)4430 void DiagnosticInfoSimdCF::emit(Instruction *Inst, StringRef Msg,
4431         DiagnosticSeverity Severity)
4432 {
4433   DiagnosticInfoSimdCF Err(Severity, *Inst->getParent()->getParent(),
4434       Inst->getDebugLoc(), Msg);
4435   Inst->getContext().diagnose(Err);
4436 }
4437 
4438