/*========================== begin_copyright_notice ============================

Copyright (C) 2017-2021 Intel Corporation

SPDX-License-Identifier: MIT

============================= end_copyright_notice ===========================*/

//
// Utility functions for the GenX backend.
//
//===----------------------------------------------------------------------===//
#include "GenXUtil.h"
#include "FunctionGroup.h"
#include "GenX.h"
#include "GenXIntrinsics.h"

#include "vc/GenXOpts/Utils/InternalMetadata.h"
#include "vc/Utils/GenX/Printf.h"
#include "vc/Utils/General/Types.h"

#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/PostOrderIterator.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/GenXIntrinsics/GenXIntrinsics.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/Metadata.h"
#include "llvm/IR/Module.h"

#include "llvmWrapper/IR/DerivedTypes.h"
#include "llvmWrapper/IR/InstrTypes.h"
#include "llvmWrapper/IR/Instructions.h"

#include "Probe/Assertion.h"
#include <algorithm>
#include <climits>
#include <cstddef>
#include <iterator>
#include <limits>
#include <set>
#include <utility>
#include <vector>

using namespace llvm;
using namespace genx;

namespace {
struct InstScanner {
  Instruction *Original;
  Instruction *Current;
  InstScanner(Instruction *Inst) : Original(Inst), Current(Inst) {}
};

} // namespace

/***********************************************************************
 * createConvert : create a genx_convert intrinsic call
 *
 * Enter:   In = value to convert
 *          Name = name to give convert instruction
 *          InsertBefore = instruction to insert before else 0
 *          M = Module (can be 0 as long as InsertBefore is not 0)
 */
CallInst *genx::createConvert(Value *In, const Twine &Name,
    Instruction *InsertBefore, Module *M)
{
  if (!M)
    M = InsertBefore->getParent()->getParent()->getParent();
  Function *Decl = GenXIntrinsic::getGenXDeclaration(M, GenXIntrinsic::genx_convert,
      In->getType());
  return CallInst::Create(Decl, In, Name, InsertBefore);
}
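
// Illustrative usage (a sketch, not part of the original file): given a
// hypothetical value V and insertion point IP, the Module argument can be
// omitted because it is recovered from IP:
//
//   CallInst *Conv = genx::createConvert(V, "v.conv", IP);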

/***********************************************************************
 * createConvertAddr : create a genx_convert_addr intrinsic call
 *
 * Enter:   In = value to convert
 *          Offset = constant offset
 *          Name = name to give convert instruction
 *          InsertBefore = instruction to insert before else 0
 *          M = Module (can be 0 as long as InsertBefore is not 0)
 */
CallInst *genx::createConvertAddr(Value *In, int Offset, const Twine &Name,
    Instruction *InsertBefore, Module *M)
{
  if (!M)
    M = InsertBefore->getParent()->getParent()->getParent();
  auto OffsetVal = ConstantInt::get(In->getType()->getScalarType(), Offset);
  Function *Decl = GenXIntrinsic::getGenXDeclaration(M, GenXIntrinsic::genx_convert_addr,
      In->getType());
  Value *Args[] = { In, OffsetVal };
  return CallInst::Create(Decl, Args, Name, InsertBefore);
}

/***********************************************************************
 * createAddAddr : create a genx_add_addr intrinsic call
 *
 * InsertBefore can be 0 so the new instruction is not inserted anywhere,
 * but in that case M must be non-0 and set to the Module.
 */
CallInst *genx::createAddAddr(Value *Lhs, Value *Rhs, const Twine &Name,
    Instruction *InsertBefore, Module *M)
{
  if (!M)
    M = InsertBefore->getParent()->getParent()->getParent();
  Value *Args[] = {Lhs, Rhs};
  Type *Tys[] = {Rhs->getType(), Lhs->getType()};
  Function *Decl = GenXIntrinsic::getGenXDeclaration(M, GenXIntrinsic::genx_add_addr, Tys);
  return CallInst::Create(Decl, Args, Name, InsertBefore);
}

/***********************************************************************
 * createUnifiedRet : create a dummy instruction that produces dummy
 * unified return value.
 *
 * %Name.unifiedret = call Ty @llvm.ssa_copy(Ty undef)
 */
CallInst *genx::createUnifiedRet(Type *Ty, const Twine &Name, Module *M) {
  IGC_ASSERT_MESSAGE(Ty, "wrong argument");
  IGC_ASSERT_MESSAGE(M, "wrong argument");
  auto G = Intrinsic::getDeclaration(M, Intrinsic::ssa_copy, Ty);
  return CallInst::Create(G, UndefValue::get(Ty), Name + ".unifiedret",
                          static_cast<Instruction *>(nullptr));
}

/***********************************************************************
 * getPredicateConstantAsInt : get an i1 or vXi1 constant's value as a single
 * integer
 *
 * Elements of constant \p C are encoded as least significant bits of the
 * result. For the scalar case only the LSB of the result is set to the
 * corresponding value.
 */
unsigned genx::getPredicateConstantAsInt(const Constant *C) {
  IGC_ASSERT_MESSAGE(C->getType()->isIntOrIntVectorTy(1),
    "wrong argument: constant of i1 or Nxi1 type was expected");
  if (auto CI = dyn_cast<ConstantInt>(C))
    return CI->getZExtValue(); // scalar
  unsigned Bits = 0;
  unsigned NumElements =
      cast<IGCLLVM::FixedVectorType>(C->getType())->getNumElements();
  IGC_ASSERT_MESSAGE(NumElements <= sizeof(Bits) * CHAR_BIT,
    "vector has too many elements; it won't fit into Bits");
  for (unsigned i = 0; i != NumElements; ++i) {
    auto El = C->getAggregateElement(i);
    if (!isa<UndefValue>(El))
      Bits |= (cast<ConstantInt>(El)->getZExtValue() & 1) << i;
  }
  return Bits;
}

/***********************************************************************
 * getConstantSubvector : get a contiguous region from a vector constant
 */
Constant *genx::getConstantSubvector(const Constant *V, unsigned StartIdx,
                                     unsigned Size) {
  Type *ElTy = cast<VectorType>(V->getType())->getElementType();
  Type *RegionTy = IGCLLVM::FixedVectorType::get(ElTy, Size);
  Constant *SubVec = nullptr;
  if (isa<UndefValue>(V))
    SubVec = UndefValue::get(RegionTy);
  else if (isa<ConstantAggregateZero>(V))
    SubVec = ConstantAggregateZero::get(RegionTy);
  else {
    SmallVector<Constant *, 32> Val;
    for (unsigned i = 0; i != Size; ++i)
      Val.push_back(V->getAggregateElement(i + StartIdx));
    SubVec = ConstantVector::get(Val);
  }
  return SubVec;
}
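
// Illustrative example: StartIdx=2, Size=2 applied to the constant
// <4 x i32> <0, 1, 2, 3> yields <2 x i32> <2, 3>; undef and zeroinitializer
// inputs are handled wholesale, without copying individual elements.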

/***********************************************************************
 * concatConstants : concatenate two possibly vector constants, giving a
 *      vector constant
 */
Constant *genx::concatConstants(Constant *C1, Constant *C2)
{
  IGC_ASSERT(C1->getType()->getScalarType() == C2->getType()->getScalarType());
  Constant *CC[] = { C1, C2 };
  SmallVector<Constant *, 8> Vec;
  bool AllUndef = true;
  for (unsigned Idx = 0; Idx != 2; ++Idx) {
    Constant *C = CC[Idx];
    if (auto *VT = dyn_cast<IGCLLVM::FixedVectorType>(C->getType())) {
      for (unsigned i = 0, e = VT->getNumElements(); i != e; ++i) {
        Constant *El = C->getAggregateElement(i);
        Vec.push_back(El);
        AllUndef &= isa<UndefValue>(El);
      }
    } else {
      Vec.push_back(C);
      AllUndef &= isa<UndefValue>(C);
    }
  }
  auto Res = ConstantVector::get(Vec);
  if (AllUndef)
    Res = UndefValue::get(Res->getType());
  return Res;
}
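
// Illustrative example: concatenating the scalar i32 7 with <2 x i32> <1, 2>
// yields <3 x i32> <7, 1, 2>; if every element of both inputs is undef, a
// single undef vector of the combined type is returned instead.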

/***********************************************************************
 * findClosestCommonDominator : find closest common dominator of some instructions
 *
 * Enter:   DT = dominator tree
 *          Insts = the instructions
 *
 * Return:  The one instruction that dominates all the others, if any.
 *          Otherwise the terminator of the closest common dominating basic
 *          block.
 */
Instruction *genx::findClosestCommonDominator(DominatorTree *DT,
    ArrayRef<Instruction *> Insts)
{
  IGC_ASSERT(!Insts.empty());
  SmallVector<InstScanner, 8> InstScanners;
  // Find the closest common dominating basic block.
  Instruction *Inst0 = Insts[0];
  BasicBlock *NCD = Inst0->getParent();
  InstScanners.push_back(InstScanner(Inst0));
  for (unsigned ii = 1, ie = Insts.size(); ii != ie; ++ii) {
    Instruction *Inst = Insts[ii];
    if (Inst->getParent() != NCD) {
      auto NewNCD = DT->findNearestCommonDominator(NCD, Inst->getParent());
      if (NewNCD != NCD)
        InstScanners.clear();
      NCD = NewNCD;
    }
    if (NCD == Inst->getParent())
      InstScanners.push_back(Inst);
  }
  // Now we have NCD = the closest common dominating basic block, and
  // InstScanners populated with the instructions from Insts that are
  // in that block.
  if (InstScanners.empty()) {
    // No instructions in that block. Return the block's terminator.
    return NCD->getTerminator();
  }
  if (InstScanners.size() == 1) {
    // Only one instruction in that block. Return it.
    return InstScanners[0].Original;
  }
  // Create a set of the original instructions.
  std::set<Instruction *> OrigInsts;
  for (auto i = InstScanners.begin(), e = InstScanners.end(); i != e; ++i)
    OrigInsts.insert(i->Original);
  // Scan back one instruction at a time for each scanner. If a scanner reaches
  // another original instruction, the scanner can be removed, and when we are
  // left with one scanner, that must be the earliest of the original
  // instructions.  If a scanner reaches the start of the basic block, that was
  // the earliest of the original instructions.
  //
  // In the worst case, this algorithm could scan all the instructions in a
  // basic block, but it is designed to be better than that in the common case
  // that the original instructions are close to each other.
  for (;;) {
    for (auto i = InstScanners.begin(), e = InstScanners.end(); i != e; ++i) {
      if (i->Current == &i->Current->getParent()->front())
        return i->Original; // reached start of basic block
      i->Current = i->Current->getPrevNode();
      if (OrigInsts.find(i->Current) != OrigInsts.end()) {
        // Scanned back to another instruction in our original set. Remove
        // this scanner.
        *i = InstScanners.back();
        InstScanners.pop_back();
        if (InstScanners.size() == 1)
          return InstScanners[0].Original; // only one scanner left
        break; // restart loop so as not to confuse the iterator
      }
    }
  }
}
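
// Usage sketch (hypothetical names): given a computed DominatorTree DT and a
// set of instructions Uses that all need one insertion point,
//
//   Instruction *Pt = genx::findClosestCommonDominator(&DT, Uses);
//
// new code placed immediately before Pt dominates every instruction in Uses.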

/***********************************************************************
 * getTwoAddressOperandNum : get operand number of two address operand
 *
 * If an intrinsic has a "two address operand", then that operand must be
 * in the same register as the result. This function returns the operand number
 * of the two address operand if any, or None if not.
 */
llvm::Optional<unsigned> genx::getTwoAddressOperandNum(CallInst *CI)
{
  auto IntrinsicID = GenXIntrinsic::getAnyIntrinsicID(CI);
  if (IntrinsicID == GenXIntrinsic::not_any_intrinsic)
    return None; // not intrinsic
  // wr(pred(pred))region has operand 0 as two address operand
  if (GenXIntrinsic::isWrRegion(IntrinsicID) ||
      IntrinsicID == GenXIntrinsic::genx_wrpredregion ||
      IntrinsicID == GenXIntrinsic::genx_wrpredpredregion)
    return GenXIntrinsic::GenXRegion::OldValueOperandNum;
  if (CI->getType()->isVoidTy())
    return None; // no return value
  GenXIntrinsicInfo II(IntrinsicID);
  unsigned Num = CI->getNumArgOperands();
  if (!Num)
    return None; // no args
  --Num; // Num = last arg number, could be two address operand
  if (isa<UndefValue>(CI->getOperand(Num)))
    return None; // operand is undef, must be RAW_NULLALLOWED
  if (II.getArgInfo(Num).getCategory() != GenXIntrinsicInfo::TWOADDR)
    return None; // not two addr operand
  if (CI->use_empty() && II.getRetInfo().rawNullAllowed())
    return None; // unused result will be V0
  return Num; // it is two addr
}
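
// Illustrative example: for a genx.wrregion call the "old value" (operand 0)
// must share a register with the result, so this returns 0; for an intrinsic
// whose last argument is not categorized TWOADDR it returns None.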

/***********************************************************************
 * isNot : test whether an instruction is a "not" instruction (an xor with
 *    constant all ones)
 */
bool genx::isNot(Instruction *Inst)
{
  if (Inst->getOpcode() == Instruction::Xor)
    if (auto C = dyn_cast<Constant>(Inst->getOperand(1)))
      if (C->isAllOnesValue())
        return true;
  return false;
}

/***********************************************************************
 * isPredNot : test whether an instruction is a "not" instruction (an xor
 *    with constant all ones) with predicate (i1 or vector of i1) type
 */
bool genx::isPredNot(Instruction *Inst)
{
  if (Inst->getOpcode() == Instruction::Xor)
    if (auto C = dyn_cast<Constant>(Inst->getOperand(1)))
      if (C->isAllOnesValue() && C->getType()->getScalarType()->isIntegerTy(1))
        return true;
  return false;
}

/***********************************************************************
 * isIntNot : test whether an instruction is a "not" instruction (an xor
 *    with constant all ones) with non-predicate type
 */
bool genx::isIntNot(Instruction *Inst)
{
  if (Inst->getOpcode() == Instruction::Xor)
    if (auto C = dyn_cast<Constant>(Inst->getOperand(1)))
      if (C->isAllOnesValue() && !C->getType()->getScalarType()->isIntegerTy(1))
        return true;
  return false;
}

/***********************************************************************
 * invertCondition : Invert the given predicate value, possibly reusing
 * an existing copy.
 */
Value *genx::invertCondition(Value *Condition)
{
  IGC_ASSERT_MESSAGE(Condition->getType()->getScalarType()->isIntegerTy(1),
    "Condition is not of predicate type");
  // First: Check if it's a constant.
  if (Constant *C = dyn_cast<Constant>(Condition))
    return ConstantExpr::getNot(C);

  // Second: If the condition is already inverted, return the original value.
  Instruction *Inst = dyn_cast<Instruction>(Condition);
  if (Inst && isPredNot(Inst))
    return Inst->getOperand(0);

  // Last option: Create a new instruction.
  auto *Inverted =
      BinaryOperator::CreateNot(Condition, Condition->getName() + ".inv");
  if (Inst && !isa<PHINode>(Inst))
    Inverted->insertAfter(Inst);
  else {
    BasicBlock *Parent = nullptr;
    if (Inst)
      Parent = Inst->getParent();
    else if (Argument *Arg = dyn_cast<Argument>(Condition))
      Parent = &Arg->getParent()->getEntryBlock();
    IGC_ASSERT_MESSAGE(Parent, "Unsupported condition to invert");
    Inverted->insertBefore(&*Parent->getFirstInsertionPt());
  }
  return Inverted;
}
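
// Illustrative behavior (hypothetical IR): if %q was created as
//   %q = xor <4 x i1> %p, <i1 true, i1 true, i1 true, i1 true>
// then invertCondition(%q) simply returns %p instead of emitting a double
// negation; constant conditions are folded directly via ConstantExpr::getNot.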

/***********************************************************************
 * isNoopCast : test if a cast operation doesn't modify the bitwise
 * representation of a value (in other words, it can be copy-coalesced).
 * NOTE: LLVM has the CastInst::isNoopCast method, but it conservatively
 * treats AddrSpaceCast as a modifying operation; this function can be more
 * aggressive, relying on DataLayout information.
 */
bool genx::isNoopCast(const CastInst *CI) {
  const DataLayout &DL = CI->getModule()->getDataLayout();
  switch (CI->getOpcode()) {
  case Instruction::BitCast:
    return true;
  case Instruction::PtrToInt:
  case Instruction::IntToPtr:
  case Instruction::AddrSpaceCast:
    return vc::getTypeSize(CI->getDestTy(), &DL) ==
           vc::getTypeSize(CI->getSrcTy(), &DL);
  default:
    return false;
  }
}
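
// Illustrative example: an addrspacecast between address spaces whose
// pointers have equal size in the DataLayout is a no-op here (and therefore
// copy-coalesceable), whereas CastInst::isNoopCast would reject it.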

/***********************************************************************
 * ShuffleVectorAnalyzer::getAsSlice : see if the shufflevector is a slice on
 *    operand 0, and if so return the start index, or -1 if it is not a slice
 */
int ShuffleVectorAnalyzer::getAsSlice()
{
  unsigned WholeWidth =
      cast<IGCLLVM::FixedVectorType>(SI->getOperand(0)->getType())
          ->getNumElements();
  Constant *Selector = IGCLLVM::getShuffleMaskForBitcode(SI);
  unsigned Width =
      cast<IGCLLVM::FixedVectorType>(SI->getType())->getNumElements();
  auto *Aggr = Selector->getAggregateElement(0u);
  if (isa<UndefValue>(Aggr))
    return -1; // first mask element is undef
  unsigned StartIdx = cast<ConstantInt>(Aggr)->getZExtValue();
  if (StartIdx >= WholeWidth)
    return -1; // start index beyond operand 0
  unsigned SliceWidth;
  for (SliceWidth = 1; SliceWidth != Width; ++SliceWidth) {
    auto CI = dyn_cast<ConstantInt>(Selector->getAggregateElement(SliceWidth));
    if (!CI)
      break;
    if (CI->getZExtValue() != StartIdx + SliceWidth)
      return -1; // not slice
  }
  return StartIdx;
}
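
// Worked example: with operand 0 of width 8, the mask <2, 3, 4, 5> reads a
// contiguous slice starting at element 2, so getAsSlice() returns 2, whereas
// the non-contiguous mask <2, 4, 5, 6> yields -1.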

/***********************************************************************
 * ShuffleVectorAnalyzer::isReplicatedSlice : check if the shufflevector
 * is a replicated slice on operand 0.
 */
bool ShuffleVectorAnalyzer::isReplicatedSlice() const {
  const auto MaskVals = SI->getShuffleMask();
  auto Begin = MaskVals.begin();
  auto End = MaskVals.end();

  // Check for undefs.
  if (std::find(Begin, End, -1) != End)
    return false;

  if (MaskVals.size() == 1)
    return true;

  // The slice should not touch the second operand.
  auto MaxIndex = static_cast<size_t>(MaskVals.back());
  if (MaxIndex >= cast<IGCLLVM::FixedVectorType>(SI->getOperand(0)->getType())
                      ->getNumElements())
    return false;

  // Find the first non-one difference.
  auto SliceEnd =
      std::adjacent_find(Begin, End,
                         [](int Prev, int Next) { return Next - Prev != 1; });
  // If not found, then it is a simple slice.
  if (SliceEnd == End)
    return true;

  // Compare the slice with parts of the sequence to prove that it is periodic.
  ++SliceEnd;
  unsigned SliceSize = std::distance(Begin, SliceEnd);
  // The slice should be replicated.
  if (MaskVals.size() % SliceSize != 0)
    return false;

  for (auto It = SliceEnd; It != End; std::advance(It, SliceSize))
    if (!std::equal(Begin, SliceEnd, It))
      return false;

  return true;
}
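
// Worked example: the mask <0, 1, 0, 1, 0, 1> is a replicated slice (the
// period <0, 1> repeats three times), while <0, 1, 1, 2> is not; any -1
// (undef) element disqualifies the mask outright.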

Value *genx::getMaskOperand(const Instruction *Inst) {
  IGC_ASSERT(Inst);

  // Return null for any instruction other than a GenX intrinsic.
  auto *CI = dyn_cast<CallInst>(Inst);
  if (!CI || !GenXIntrinsic::isGenXIntrinsic(CI))
    return nullptr;

  auto MaskOpIt = llvm::find_if(CI->operands(), [](const Use &U) {
    Value *Operand = U.get();
    if (auto *VT = dyn_cast<VectorType>(Operand->getType()))
      return VT->getElementType()->isIntegerTy(1);
    return false;
  });

  // No mask among the operands.
  if (MaskOpIt == CI->op_end())
    return nullptr;

  return *MaskOpIt;
}

// Based on the value of a shufflevector mask element, determine which of the
// two operands it points into, and return that operand.
static Value *getOperandByMaskValue(const ShuffleVectorInst &SI,
                                    int MaskValue) {
  IGC_ASSERT_MESSAGE(MaskValue >= 0, "invalid index");
  int FirstOpSize = cast<IGCLLVM::FixedVectorType>(SI.getOperand(0)->getType())
                        ->getNumElements();
  if (MaskValue < FirstOpSize)
    return SI.getOperand(0);
  else {
    int SecondOpSize =
        cast<IGCLLVM::FixedVectorType>(SI.getOperand(1)->getType())
            ->getNumElements();
    IGC_ASSERT_MESSAGE(MaskValue < FirstOpSize + SecondOpSize, "invalid index");
    return SI.getOperand(1);
  }
}

// Safe advance:
// if adding \p N would violate the bound, \p Last is written to \p It.
template <typename Iter> void advanceSafe(Iter &It, Iter Last, int N) {
  if (N > std::distance(It, Last)) {
    It = Last;
    return;
  }
  std::advance(It, N);
}

// Returns the operand and its one-element region that is referenced by
// the \p MaskVal element of the shufflevector mask.
static ShuffleVectorAnalyzer::OperandRegionInfo
matchOneElemRegion(const ShuffleVectorInst &SI, int MaskVal) {
  ShuffleVectorAnalyzer::OperandRegionInfo Init;
  Init.Op = getOperandByMaskValue(SI, MaskVal);
  Init.R = Region(Init.Op);
  Init.R.NumElements = Init.R.Width = 1;
  if (Init.Op == SI.getOperand(0))
    Init.R.Offset = MaskVal * Init.R.ElementBytes;
  else {
    auto FirstOpSize =
        cast<IGCLLVM::FixedVectorType>(SI.getOperand(0)->getType())
            ->getNumElements();
    Init.R.Offset = (MaskVal - FirstOpSize) * Init.R.ElementBytes;
  }
  return Init;
}

class MaskIndex {
  int Idx;
  static constexpr const int Undef = -1;
  static constexpr const int AnotherOp = -2;

public:
  explicit MaskIndex(int InitIdx = 0) : Idx(InitIdx) {
    IGC_ASSERT_MESSAGE(Idx >= 0, "Defined index must not be negative");
  }

  static MaskIndex getUndef() {
    MaskIndex Ret;
    Ret.Idx = Undef;
    return Ret;
  }

  static MaskIndex getAnotherOp() {
    MaskIndex Ret;
    Ret.Idx = AnotherOp;
    return Ret;
  }
  bool isUndef() const { return Idx == Undef; }
  bool isAnotherOp() const { return Idx == AnotherOp; }
  bool isDefined() const { return Idx >= 0; }

  int get() const {
    IGC_ASSERT_MESSAGE(Idx >= 0, "Can't call get() on invalid index");
    return Idx;
  }

  int operator-(MaskIndex const &rhs) const {
    IGC_ASSERT_MESSAGE(isDefined(), "All operand indices must be valid");
    IGC_ASSERT_MESSAGE(rhs.isDefined(), "All operand indices must be valid");
    return Idx - rhs.Idx;
  }
};

// Takes shufflevector mask indexes from [\p FirstIt, \p LastIt),
// converts them to indices of the \p Operand of the \p SI instruction
// and writes them to \p OutIt. The value type of OutIt is MaskIndex.
template <typename ForwardIter, typename OutputIter>
void makeSVIIndexesOperandIndexes(const ShuffleVectorInst &SI,
                                  const Value &Operand, ForwardIter FirstIt,
                                  ForwardIter LastIt, OutputIter OutIt) {
  int FirstOpSize = cast<IGCLLVM::FixedVectorType>(SI.getOperand(0)->getType())
                        ->getNumElements();
  if (&Operand == SI.getOperand(0)) {
    std::transform(FirstIt, LastIt, OutIt, [FirstOpSize](int MaskVal) {
      if (MaskVal >= FirstOpSize)
        return MaskIndex::getAnotherOp();
      return MaskVal < 0 ? MaskIndex::getUndef() : MaskIndex{MaskVal};
    });
    return;
  }
  IGC_ASSERT_MESSAGE(&Operand == SI.getOperand(1),
    "wrong argument: a shufflevector operand was expected");
  std::transform(FirstIt, LastIt, OutIt, [FirstOpSize](int MaskVal) {
    if (MaskVal < 0)
      return MaskIndex::getUndef();
    return MaskVal >= FirstOpSize ? MaskIndex{MaskVal - FirstOpSize}
                                  : MaskIndex::getAnotherOp();
  });
}

// Calculates the horizontal stride for a region by scanning mask indices in
// the range [\p FirstIt, \p LastIt).
//
// Arguments:
//    [\p FirstIt, \p LastIt) is the range of MaskIndex. There must not be
//    any AnotherOp indices in the range.
//    \p FirstIt must point to a defined index.
// Return value:
//    std::pair whose first element is an iterator to the next defined element
//    (or std::next(FirstIt) if there is no such element) and whose second
//    element is the estimated stride if it is non-negative and integral,
//    or an empty value otherwise.
template <typename ForwardIter>
std::pair<ForwardIter, llvm::Optional<int>>
estimateHorizontalStride(ForwardIter FirstIt, ForwardIter LastIt) {

  IGC_ASSERT_MESSAGE(FirstIt != LastIt, "the range must contain at least 1 element");
  IGC_ASSERT_MESSAGE(std::none_of(FirstIt, LastIt, [](MaskIndex Idx) { return Idx.isAnotherOp(); }),
   "There must not be any AnotherOp indices in the range");
  IGC_ASSERT_MESSAGE(FirstIt->isDefined(),
    "first element in range must be a valid index");
  auto NextDefined =
      std::find_if(std::next(FirstIt), LastIt,
                   [](MaskIndex Elem) { return Elem.isDefined(); });

  if (NextDefined == LastIt)
    return {std::next(FirstIt), llvm::Optional<int>{}};

  int TotalStride = *NextDefined - *FirstIt;
  int TotalWidth = std::distance(FirstIt, NextDefined);

  if (TotalStride < 0 || (TotalStride % TotalWidth != 0 && TotalStride != 0))
    return {NextDefined, llvm::Optional<int>{}};

  return {NextDefined, TotalStride / TotalWidth};
}
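
// Worked example: for defined indices {2, 4, 6} the first step sees
// TotalStride = 2 over TotalWidth = 1, i.e. stride 2; for {2, undef, 6} the
// undef is skipped and the stride is (6 - 2) / 2 == 2 as well.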

// Matches a "vector" region (with vstride == 0) pattern in the
// [\p FirstIt, \p LastIt) indexes.
// Uses the info in \p FirstElemRegion, adds the deduced Width, Stride and
// new NumElements to \p FirstElemRegion and returns the resulting region.
//
// Arguments:
//    [\p FirstIt, \p LastIt) is the range of MaskIndex. There must not be
//    any AnotherOp indices in the range.
//    FirstIt and std::prev(LastIt) must point to defined indices.
//    \p FirstElemRegion describes a one-element region with the only index
//    *FirstIt.
//    \p BoundIndex is the maximum possible index of the input vector + 1
//    (BoundIndex == InputVector.length)
template <typename ForwardIter>
Region matchVectorRegionByIndexes(Region FirstElemRegion, ForwardIter FirstIt,
                                  ForwardIter LastIt, int BoundIndex) {
  IGC_ASSERT_MESSAGE(FirstIt != LastIt, "the range must contain at least 1 element");
  IGC_ASSERT_MESSAGE(std::none_of(FirstIt, LastIt, [](MaskIndex Idx) { return Idx.isAnotherOp(); }),
    "There must not be any AnotherOp indices in the range.");
  IGC_ASSERT_MESSAGE(FirstIt->isDefined(),
    "expected FirstIt and --LastIt point to valid indices");
  IGC_ASSERT_MESSAGE(std::prev(LastIt)->isDefined(),
    "expected FirstIt and --LastIt point to valid indices");

  if (std::distance(FirstIt, LastIt) == 1)
    return FirstElemRegion;

  llvm::Optional<int> RefStride;
  ForwardIter NewRowIt;
  std::tie(NewRowIt, RefStride) = estimateHorizontalStride(FirstIt, LastIt);

  if (!RefStride)
    return FirstElemRegion;

  llvm::Optional<int> Stride = RefStride;
  while (Stride == RefStride)
    std::tie(NewRowIt, Stride) = estimateHorizontalStride(NewRowIt, LastIt);

  auto TotalStride = std::distance(FirstIt, std::prev(NewRowIt)) * *RefStride;
  auto Overstep = TotalStride + FirstIt->get() - BoundIndex + 1;
  if (Overstep > 0)
    NewRowIt = std::prev(NewRowIt, llvm::divideCeil(Overstep, *RefStride));

  int Width = std::distance(FirstIt, NewRowIt);
  IGC_ASSERT_MESSAGE(Width > 0, "should be at least 1 according to algorithm");
  if (Width == 1)
    // Stride plays no role when the Width is 1; zeroing it also avoids
    // writing an unnecessarily large value into the region.
    RefStride = 0;
  FirstElemRegion.Stride = *RefStride;
  FirstElemRegion.Width = Width;
  FirstElemRegion.NumElements = Width;
  return FirstElemRegion;
}

// Calculates the vertical stride for a region by scanning mask indices in
// the range [\p FirstIt, \p LastIt).
//
// Arguments:
//    \p FirstRowRegion describes a "vector" region (with vstride == 0),
//      which is formed by the first 'FirstRowRegion.NumElements' elements
//      of the range.
//    \p FirstIt points to the first element of the vector of indices.
//    \p ReferenceIt points to some valid reference element in that vector;
//    it must be outside the first 'FirstRowRegion.NumElements' elements.
//    The first 'FirstRowRegion.NumElements' elements in the range must be
//    defined indices.
// Return value:
//    Value of the estimated vertical stride if it is non-negative and
//    integral, empty value otherwise.
template <typename ForwardIter>
llvm::Optional<int> estimateVerticalStride(Region FirstRowRegion,
                                           ForwardIter FirstIt,
                                           ForwardIter ReferenceIt) {

  IGC_ASSERT_MESSAGE(std::distance(FirstIt, ReferenceIt) >= static_cast<std::ptrdiff_t>(FirstRowRegion.Width),
    "Reference element must not be part of first row");
  IGC_ASSERT_MESSAGE(std::all_of(FirstIt, std::next(FirstIt, FirstRowRegion.Width), [](MaskIndex Elem) { return Elem.isDefined(); }),
    "First row must contain only valid indices");
  IGC_ASSERT_MESSAGE(ReferenceIt->isDefined(), "Reference index must be valid");

  int Width = FirstRowRegion.Width;

  int TotalDistance = std::distance(FirstIt, ReferenceIt);
  int VStridesToDef = TotalDistance / Width;
  int HStridesToDef = TotalDistance % Width;
  int TotalVerticalStride = *ReferenceIt - *std::next(FirstIt, HStridesToDef);
  if (TotalVerticalStride < 0 || TotalVerticalStride % VStridesToDef != 0)
    return llvm::Optional<int>{};

  return llvm::Optional<int>{TotalVerticalStride / VStridesToDef};
}

// Matches a "matrix" region (whose vstride may be nonzero) pattern in the
// [\p FirstIt, \p LastIt) indexes.
// Uses the info in \p FirstRowRegion, adds the deduced VStride and new
// NumElements to \p FirstRowRegion and returns the resulting region.
//
// Arguments:
//    [\p FirstIt, \p LastIt) is the range of MaskIndex. Note that this
//    pass may change the contents of this vector (replace undef indices
//    with defined ones), so it can affect further usage.
//    \p LastDefinedIt points to the last element in the vector with a
//    defined index.
//    \p FirstRowRegion describes a "vector" region (with vstride == 0),
//      which is formed by the first 'FirstRowRegion.NumElements' elements
//      of the range.
//    \p BoundIndex is the maximum possible index of the input vector + 1
//    (BoundIndex == InputVector.length)
template <typename ForwardIter>
Region matchMatrixRegionByIndexes(Region FirstRowRegion, ForwardIter FirstIt,
                                  ForwardIter LastIt, ForwardIter LastDefinedIt,
                                  int BoundIndex) {
  IGC_ASSERT_MESSAGE(FirstRowRegion.NumElements == FirstRowRegion.Width,
    "wrong argument: vector region (with no vstride) was expected");
  IGC_ASSERT_MESSAGE(FirstRowRegion.VStride == 0,
    "wrong argument: vector region (with no vstride) was expected");
  IGC_ASSERT_MESSAGE(FirstIt->isDefined(),
    "expected FirstIt and LastDefinedIt point to valid indices");
  IGC_ASSERT_MESSAGE(LastDefinedIt->isDefined(),
    "expected FirstIt and LastDefinedIt point to valid indices");
  IGC_ASSERT_MESSAGE(std::distance(FirstIt, LastIt) >= static_cast<std::ptrdiff_t>(FirstRowRegion.Width),
    "wrong argument: number of indexes must be at least equal to region width");

  auto FirstRowEndIt = std::next(FirstIt, FirstRowRegion.Width);
  if (FirstRowEndIt == LastIt)
    return FirstRowRegion;

  auto FirstDefined = std::find_if(FirstRowEndIt, LastIt, [](MaskIndex Idx) {
    return Idx.isDefined() || Idx.isAnotherOp();
  });
  if (FirstDefined == LastIt || FirstDefined->isAnotherOp())
    return FirstRowRegion;

  int Stride = FirstRowRegion.Stride;
  int Idx = FirstIt->get();
  std::generate(std::next(FirstIt), FirstRowEndIt,
                [Idx, Stride]() mutable { return MaskIndex{Idx += Stride}; });

  llvm::Optional<int> VStride =
      estimateVerticalStride(FirstRowRegion, FirstIt, FirstDefined);
  if (!VStride)
    return FirstRowRegion;

  int VDistance = *VStride;
  int Width = FirstRowRegion.Width;
  int NumElements = FirstRowRegion.Width;
  int HighestFirstRowElement = std::prev(FirstRowEndIt)->get();

  for (auto It = FirstRowEndIt; It != LastIt; advanceSafe(It, LastIt, Width),
            NumElements += Width, VDistance += *VStride) {
    if (It > LastDefinedIt || std::distance(It, LastIt) < Width ||
        HighestFirstRowElement + VDistance >= BoundIndex ||
        !std::equal(FirstIt, FirstRowEndIt, It,
                    [VDistance](MaskIndex Reference, MaskIndex Current) {
                      return !Current.isAnotherOp() &&
                             (Current.isUndef() ||
                              Current.get() - Reference.get() == VDistance);
                    }))
      break;
  }

  if (NumElements == Width)
    // VStride plays no role when the Width is equal to NumElements;
    // zeroing it also avoids writing an unnecessarily large value into
    // the region.
    VStride = 0;
  FirstRowRegion.VStride = *VStride;
  FirstRowRegion.NumElements = NumElements;
  return FirstRowRegion;
}

// Analyzes the shufflevector mask starting from the \p StartIdx element.
// Finds the longest prefix of the truncated shufflevector mask that can be
// represented as a region of one operand of the instruction.
// Returns the operand and its region.
//
// For example:
// {0, 1, 3, 4, 25, 16 ...} -> first 4 elements form a region:
//                             <3;2,1> vstride=3, width=2, stride=1
ShuffleVectorAnalyzer::OperandRegionInfo
ShuffleVectorAnalyzer::getMaskRegionPrefix(int StartIdx) {
  IGC_ASSERT_MESSAGE(StartIdx >= 0, "Start index is out of bound");
  IGC_ASSERT_MESSAGE(StartIdx < static_cast<int>(SI->getShuffleMask().size()),
    "Start index is out of bound");

  auto MaskVals = SI->getShuffleMask();
  auto StartIt = std::next(MaskVals.begin(), StartIdx);
  OperandRegionInfo Res = matchOneElemRegion(*SI, *StartIt);

  if (StartIdx == MaskVals.size() - 1)
    return Res;

  std::vector<MaskIndex> SubMask;
  makeSVIIndexesOperandIndexes(*SI, *Res.Op, StartIt, MaskVals.end(),
                               std::back_inserter(SubMask));

  auto FirstAnotherOpElement =
      std::find_if(SubMask.begin(), SubMask.end(),
                   [](MaskIndex Elem) { return Elem.isAnotherOp(); });
  auto PastLastDefinedElement =
      std::find_if(std::reverse_iterator(FirstAnotherOpElement),
                   std::reverse_iterator(SubMask.begin()),
                   [](MaskIndex Elem) { return Elem.isDefined(); })
          .base();

  Res.R = matchVectorRegionByIndexes(
      std::move(Res.R), SubMask.begin(), PastLastDefinedElement,
      cast<IGCLLVM::FixedVectorType>(Res.Op->getType())->getNumElements());
  Res.R = matchMatrixRegionByIndexes(
      std::move(Res.R), SubMask.begin(), SubMask.end(),
      std::prev(PastLastDefinedElement),
      cast<IGCLLVM::FixedVectorType>(Res.Op->getType())->getNumElements());
  return Res;
}

/***********************************************************************
 * ShuffleVectorAnalyzer::getAsUnslice : see if the shufflevector is an
 *    unslice where the "old value" is operand 0 and operand 1 is another
 *    shufflevector and operand 0 of that is the "new value"
 *
 * Return:  start index, or -1 if it is not an unslice
 */
int ShuffleVectorAnalyzer::getAsUnslice()
{
  auto SI2 = dyn_cast<ShuffleVectorInst>(SI->getOperand(1));
  if (!SI2)
    return -1;
  Constant *MaskVec = IGCLLVM::getShuffleMaskForBitcode(SI);
  // Find prefix of undef or elements from operand 0.
  unsigned OldWidth =
      cast<IGCLLVM::FixedVectorType>(SI2->getType())->getNumElements();
  unsigned NewWidth =
      cast<IGCLLVM::FixedVectorType>(SI2->getOperand(0)->getType())
          ->getNumElements();
  unsigned Prefix = 0;
  for (;; ++Prefix) {
    if (Prefix == OldWidth - NewWidth)
      break;
    Constant *IdxC = MaskVec->getAggregateElement(Prefix);
    if (isa<UndefValue>(IdxC))
      continue;
    unsigned Idx = cast<ConstantInt>(IdxC)->getZExtValue();
    if (Idx == OldWidth)
      break; // found end of prefix
    if (Idx != Prefix)
      return -1; // not part of prefix
  }
  // Check that the whole of SI2 operand 0 follows
  for (unsigned i = 1; i != NewWidth; ++i) {
    Constant *IdxC = MaskVec->getAggregateElement(Prefix + i);
    if (isa<UndefValue>(IdxC))
      continue;
    if (cast<ConstantInt>(IdxC)->getZExtValue() != i + OldWidth)
      return -1; // not got whole of SI2 operand 0
  }
  // Check that the remainder is undef or elements from operand 0.
  for (unsigned i = Prefix + NewWidth; i != OldWidth; ++i) {
    Constant *IdxC = MaskVec->getAggregateElement(i);
    if (isa<UndefValue>(IdxC))
      continue;
    if (cast<ConstantInt>(IdxC)->getZExtValue() != i)
      return -1;
  }
  // Check that the first Prefix elements of SI2 come from its operand 1.
  Constant *MaskVec2 = IGCLLVM::getShuffleMaskForBitcode(SI2);
  for (unsigned i = 0; i != Prefix; ++i) {
    Constant *IdxC = MaskVec2->getAggregateElement(Prefix + i);
    if (isa<UndefValue>(IdxC))
      continue;
    if (cast<ConstantInt>(IdxC)->getZExtValue() != i)
      return -1;
  }
  // Success.
  return Prefix;
}
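
// Worked example (hypothetical widths): with an 8-element "old value" and a
// 2-element "new value", the mask <0, 1, 2, 8, 9, 5, 6, 7> pastes the new
// value in at offset 3 and keeps the rest of the old value, so, provided the
// extra check on operand 1's own mask passes, getAsUnslice() returns 3.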

#if LLVM_VERSION_MAJOR <= 10
constexpr int UndefMaskElem = -1;
#endif

/***********************************************************************
 * extension of the ShuffleVectorInst::isZeroEltSplatMask method
 */
static int nEltSplatMask(ArrayRef<int> Mask) {
  int Elt = UndefMaskElem;
  for (int i = 0, NumElts = Mask.size(); i < NumElts; ++i) {
    if (Mask[i] == UndefMaskElem)
      continue;
    if ((Elt != UndefMaskElem) && (Mask[i] != Mask[Elt]))
      return UndefMaskElem;
    if ((Mask[i] != UndefMaskElem) && (Elt == UndefMaskElem))
      Elt = i;
  }
  return Elt;
}

/***********************************************************************
 * ShuffleVectorAnalyzer::getAsSplat : if shufflevector is a splat, get the
 *      splatted input, with its vector index if the input is a vector
 */
ShuffleVectorAnalyzer::SplatInfo ShuffleVectorAnalyzer::getAsSplat()
{
  Value *InVec1 = SI->getOperand(0);
  Value *InVec2 = SI->getOperand(1);

  SmallVector<int, 16> MaskAsInts;
  SI->getShuffleMask(MaskAsInts);
  int ShuffleIdx = nEltSplatMask(MaskAsInts);
  if (ShuffleIdx == UndefMaskElem)
    return SplatInfo(nullptr, 0);

  // nEltSplatMask returns the position of the splat index within the mask;
  // turn it into the real input-vector index.
  ShuffleIdx = MaskAsInts[ShuffleIdx];

  // The mask is a splat. Work out which element of which input vector
  // it refers to.
  int InVec1NumElements =
      cast<IGCLLVM::FixedVectorType>(InVec1->getType())->getNumElements();
  if (ShuffleIdx >= InVec1NumElements) {
    ShuffleIdx -= InVec1NumElements;
    InVec1 = InVec2;
  }
  if (auto IE = dyn_cast<InsertElementInst>(InVec1)) {
    if (InVec1NumElements == 1 || isa<UndefValue>(IE->getOperand(0)))
      return SplatInfo(IE->getOperand(1), 0);
    // Even though this is a splat, the input vector has more than one
    // element. IRBuilder::CreateVectorSplat does this. See if the input
    // vector is the result of an insertelement at the right place, and
    // if so return that. Otherwise we end up allocating
    // an unnecessarily large register.
    if (auto ConstIdx = dyn_cast<ConstantInt>(IE->getOperand(2)))
      if (ConstIdx->getSExtValue() == ShuffleIdx)
        return SplatInfo(IE->getOperand(1), 0);
  }
  return SplatInfo(InVec1, ShuffleIdx);
}
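
// Illustrative example: for the common IRBuilder::CreateVectorSplat pattern
//   %v = insertelement <4 x i32> undef, i32 %x, i32 0
//   %s = shufflevector <4 x i32> %v, <4 x i32> undef, <4 x i32> zeroinitializer
// getAsSplat() looks through the insertelement and returns {%x, 0} rather
// than the whole vector, avoiding an unnecessarily large register.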

Value *ShuffleVectorAnalyzer::serialize() {
  unsigned Cost0 = getSerializeCost(0);
  unsigned Cost1 = getSerializeCost(1);

  Value *Op0 = SI->getOperand(0);
  Value *Op1 = SI->getOperand(1);
  Value *V = Op0;
  bool UseOp0AsBase = Cost0 <= Cost1;
  if (!UseOp0AsBase)
    V = Op1;

  // Expand or shrink the initial value if sizes mismatch.
  unsigned NElts =
      cast<IGCLLVM::FixedVectorType>(SI->getType())->getNumElements();
  unsigned M = cast<IGCLLVM::FixedVectorType>(V->getType())->getNumElements();
  bool SkipBase = true;
  if (M != NElts) {
    if (auto C = dyn_cast<Constant>(V)) {
      SmallVector<Constant *, 16> Vals;
      for (unsigned i = 0; i < NElts; ++i) {
        Type *Ty = cast<VectorType>(C->getType())->getElementType();
        Constant *Elt =
            (i < M) ? C->getAggregateElement(i) : UndefValue::get(Ty);
        Vals.push_back(Elt);
      }
      V = ConstantVector::get(Vals);
    } else {
      // Need to insert individual elements.
      V = UndefValue::get(SI->getType());
      SkipBase = false;
    }
  }

  IRBuilder<> Builder(SI);
  for (unsigned i = 0; i < NElts; ++i) {
    // Undef index returns -1.
    int idx = SI->getMaskValue(i);
    if (idx < 0)
      continue;
    if (SkipBase) {
      if (UseOp0AsBase && idx == i)
        continue;
      if (!UseOp0AsBase && idx == i + M)
        continue;
    }

    Value *Vi = nullptr;
    if (idx < (int)M)
      Vi = Builder.CreateExtractElement(Op0, idx, "");
    else
      Vi = Builder.CreateExtractElement(Op1, idx - M, "");
    if (!isa<UndefValue>(Vi))
      V = Builder.CreateInsertElement(V, Vi, i, "");
  }

  return V;
}

unsigned ShuffleVectorAnalyzer::getSerializeCost(unsigned i) {
  unsigned Cost = 0;
  Value *Op = SI->getOperand(i);
  if (!isa<Constant>(Op) && Op->getType() != SI->getType())
    Cost += cast<IGCLLVM::FixedVectorType>(Op->getType())->getNumElements();

  unsigned NElts =
      cast<IGCLLVM::FixedVectorType>(SI->getType())->getNumElements();
  for (unsigned j = 0; j < NElts; ++j) {
    // Undef index returns -1.
    int idx = SI->getMaskValue(j);
    if (idx < 0)
      continue;
    // Count the number of elements out of place.
    unsigned M =
        cast<IGCLLVM::FixedVectorType>(Op->getType())->getNumElements();
    if ((i == 0 && idx != j) || (i == 1 && idx != j + M))
      Cost++;
  }

  return Cost;
}

IVSplitter::IVSplitter(Instruction &Inst, const unsigned *BaseOpIdx)
    : Inst(Inst) {

  ETy = Inst.getType();
  if (BaseOpIdx)
    ETy = Inst.getOperand(*BaseOpIdx)->getType();

  Len = 1;
  if (auto *EVTy = dyn_cast<IGCLLVM::FixedVectorType>(ETy)) {
    Len = EVTy->getNumElements();
    ETy = EVTy->getElementType();
  }

  VI32Ty = IGCLLVM::FixedVectorType::get(ETy->getInt32Ty(Inst.getContext()),
                                         Len * 2);
}

IVSplitter::RegionTrait IVSplitter::describeSplit(RegionType RT, size_t ElNum) {
  RegionTrait Result;
  if (RT == RegionType::LoRegion || RT == RegionType::HiRegion) {
    // Take every second element.
    Result.ElStride = 2;
    Result.ElOffset = (RT == RegionType::LoRegion) ? 0 : 1;
  }
  else if (RT == RegionType::FirstHalf || RT == RegionType::SecondHalf) {
    // Take every element, sequentially.
    Result.ElStride = 1;
    Result.ElOffset = (RT == RegionType::FirstHalf) ? 0 : ElNum;
  } else {
    IGC_ASSERT_EXIT_MESSAGE(0, "incorrect region type");
  }
  return Result;
}

Constant *
IVSplitter::splitConstantVector(const SmallVectorImpl<Constant *> &KV32,
                                RegionType RT) {
  IGC_ASSERT(KV32.size() % 2 == 0);
  SmallVector<Constant *, 16> Result;
  size_t ElNum = KV32.size() / 2;
  Result.reserve(ElNum);
  auto Split = describeSplit(RT, ElNum);
  for (size_t i = 0; i < ElNum; ++i) {
    size_t Offset = Split.ElOffset + i * Split.ElStride;
    IGC_ASSERT(Offset < KV32.size());
    Result.push_back(KV32[Offset]);
  }
  return ConstantVector::get(Result);
}

Region IVSplitter::createSplitRegion(Type *SrcTy, IVSplitter::RegionType RT) {
  IGC_ASSERT(SrcTy->isVectorTy());
  IGC_ASSERT(SrcTy->getScalarType()->isIntegerTy(32));
  IGC_ASSERT(cast<IGCLLVM::FixedVectorType>(SrcTy)->getNumElements() % 2 == 0);

  size_t Len = cast<IGCLLVM::FixedVectorType>(SrcTy)->getNumElements() / 2;

  auto Split = describeSplit(RT, Len);

  Region R(SrcTy);
  R.Width = Len;
  R.NumElements = Len;
  R.VStride = 0;
  R.Stride = Split.ElStride;
  // The offset is encoded in bytes.
  R.Offset = Split.ElOffset * 4;

  return R;
}
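
// Illustrative mapping: after the <N x i64> value is viewed as <2N x i32>,
// LoRegion selects elements {0, 2, 4, ...} (stride 2, byte offset 0),
// HiRegion selects {1, 3, 5, ...} (stride 2, byte offset 4), and
// FirstHalf/SecondHalf select the two contiguous halves with stride 1.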

// This function takes a 64-bit constant value (vector or scalar) and splits
// it into an equivalent vector of 32-bit constants (as if it was bitcast).
static void convertI64ToI32(Constant &K, SmallVectorImpl<Constant *> &K32) {
  auto I64To32 = [](const Constant &K) {
    // We expect only scalar types here.
    IGC_ASSERT(!isa<VectorType>(K.getType()));
    IGC_ASSERT(K.getType()->isIntegerTy(64));
    auto *Ty32 = K.getType()->getInt32Ty(K.getContext());
    if (isa<UndefValue>(K)) {
      Constant *Undef = UndefValue::get(Ty32);
      return std::make_pair(Undef, Undef);
    }
    auto *KI = cast<ConstantInt>(&K);
    uint64_t Val64 = KI->getZExtValue();
    const auto UI32ValueMask = std::numeric_limits<uint32_t>::max();
    Constant *VLo =
        ConstantInt::get(Ty32, static_cast<uint32_t>(Val64 & UI32ValueMask));
    Constant *VHi = ConstantInt::get(Ty32, static_cast<uint32_t>(Val64 >> 32));
    return std::make_pair(VLo, VHi);
  };

  IGC_ASSERT(K32.empty());
  if (!isa<VectorType>(K.getType())) {
    auto V32 = I64To32(K);
    K32.push_back(V32.first);
    K32.push_back(V32.second);
    return;
  }
  unsigned ElNum =
      cast<IGCLLVM::FixedVectorType>(K.getType())->getNumElements();
  K32.reserve(2 * ElNum);
  for (unsigned i = 0; i < ElNum; ++i) {
    auto V32 = I64To32(*K.getAggregateElement(i));
    K32.push_back(V32.first);
    K32.push_back(V32.second);
  }
}
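
// Worked example: the scalar i64 constant 0x0000000100000002 splits into
// lo = 2 and hi = 1, so K32 becomes {i32 2, i32 1}; vector constants are
// processed element by element in the same lo-then-hi order.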

std::pair<Value *, Value *>
IVSplitter::splitValue(Value &Val, RegionType RT1, const Twine &Name1,
                       RegionType RT2, const Twine &Name2, bool FoldConstants) {
  const auto &DL = Inst.getDebugLoc();
  auto BaseName = Inst.getName();

  IGC_ASSERT(Val.getType()->getScalarType()->isIntegerTy(64));

  if (FoldConstants && isa<Constant>(Val)) {
    SmallVector<Constant *, 32> KV32;
    convertI64ToI32(cast<Constant>(Val), KV32);
    Value *V1 = splitConstantVector(KV32, RT1);
    Value *V2 = splitConstantVector(KV32, RT2);
    return {V1, V2};
  }
  auto *ShreddedVal = new BitCastInst(&Val, VI32Ty, BaseName + ".iv32cast", &Inst);
  ShreddedVal->setDebugLoc(DL);

  auto R1 = createSplitRegion(VI32Ty, RT1);
  auto *V1 = R1.createRdRegion(ShreddedVal, BaseName + Name1, &Inst, DL);

  auto R2 = createSplitRegion(VI32Ty, RT2);
  auto *V2 = R2.createRdRegion(ShreddedVal, BaseName + Name2, &Inst, DL);
  return { V1, V2 };
}

IVSplitter::LoHiSplit IVSplitter::splitOperandLoHi(unsigned SourceIdx,
                                                   bool FoldConstants) {

  IGC_ASSERT(Inst.getNumOperands() > SourceIdx);
  return splitValueLoHi(*Inst.getOperand(SourceIdx), FoldConstants);
}
IVSplitter::HalfSplit IVSplitter::splitOperandHalf(unsigned SourceIdx,
                                                   bool FoldConstants) {

  IGC_ASSERT(Inst.getNumOperands() > SourceIdx);
  return splitValueHalf(*Inst.getOperand(SourceIdx), FoldConstants);
}

IVSplitter::LoHiSplit IVSplitter::splitValueLoHi(Value &V, bool FoldConstants) {
  auto Splitted = splitValue(V, RegionType::LoRegion, ".LoSplit",
                             RegionType::HiRegion, ".HiSplit", FoldConstants);
  return {Splitted.first, Splitted.second};
}
IVSplitter::HalfSplit IVSplitter::splitValueHalf(Value &V, bool FoldConstants) {
  auto Splitted =
      splitValue(V, RegionType::FirstHalf, ".FirstHalf", RegionType::SecondHalf,
                 ".SecondHalf", FoldConstants);
  return {Splitted.first, Splitted.second};
}
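
// Usage sketch (hypothetical names, assuming LoHiSplit's Lo/Hi fields): to
// legalize a 64-bit add, a pass might split both operands, emit 32-bit
// arithmetic, and recombine:
//
//   IVSplitter Splitter(Add);
//   auto L = Splitter.splitOperandLoHi(0);
//   auto R = Splitter.splitOperandLoHi(1);
//   // ... compute NewLo/NewHi from {L.Lo, R.Lo} and {L.Hi, R.Hi} ...
//   Value *Res = Splitter.combineLoHiSplit({NewLo, NewHi}, "add64",
//                                          !Add.getType()->isVectorTy());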

Value* IVSplitter::combineSplit(Value &V1, Value &V2, RegionType RT1,
                                RegionType RT2, const Twine& Name,
                                bool Scalarize) {
  const auto &DL = Inst.getDebugLoc();

  IGC_ASSERT(V1.getType() == V2.getType());
  IGC_ASSERT(V1.getType()->isVectorTy());
  IGC_ASSERT(cast<VectorType>(V1.getType())->getElementType()->isIntegerTy(32));

  // Create the write-regions.
  auto R1 = createSplitRegion(VI32Ty, RT1);
  auto *UndefV = UndefValue::get(VI32Ty);
  auto *W1 = R1.createWrRegion(UndefV, &V1, Name + "partial_join", &Inst, DL);

  auto R2 = createSplitRegion(VI32Ty, RT2);
  auto *W2 = R2.createWrRegion(W1, &V2, Name + "joined", &Inst, DL);

  auto *V64Ty =
      IGCLLVM::FixedVectorType::get(ETy->getInt64Ty(Inst.getContext()), Len);
  auto *Result = new BitCastInst(W2, V64Ty, Name, &Inst);
  Result->setDebugLoc(DL);

  if (Scalarize) {
    IGC_ASSERT(
        cast<IGCLLVM::FixedVectorType>(Result->getType())->getNumElements() ==
        1);
    Result = new BitCastInst(Result, ETy->getInt64Ty(Inst.getContext()),
                             Name + "recast", &Inst);
    Result->setDebugLoc(DL);
  }
  return Result;
}

Value *IVSplitter::combineLoHiSplit(const LoHiSplit &Split, const Twine &Name,
                                    bool Scalarize) {
  IGC_ASSERT(Split.Lo);
  IGC_ASSERT(Split.Hi);

  return combineSplit(*Split.Lo, *Split.Hi, RegionType::LoRegion,
                      RegionType::HiRegion, Name, Scalarize);
}

Value *IVSplitter::combineHalfSplit(const HalfSplit &Split, const Twine &Name,
                                    bool Scalarize) {
  IGC_ASSERT(Split.Left);
  IGC_ASSERT(Split.Right);

  return combineSplit(*Split.Left, *Split.Right, RegionType::FirstHalf,
                      RegionType::SecondHalf, Name, Scalarize);
}

/***********************************************************************
 * adjustPhiNodesForBlockRemoval : adjust phi nodes when removing a block
 *
 * Enter:   Succ = the successor block to adjust phi nodes in
 *          BB = the block being removed
 *
 * This modifies each phi node in Succ as follows: the incoming for BB is
 * replaced by an incoming for each of BB's predecessors.
 */
void genx::adjustPhiNodesForBlockRemoval(BasicBlock *Succ, BasicBlock *BB)
{
  for (auto i = Succ->begin(), e = Succ->end(); i != e; ++i) {
    auto Phi = dyn_cast<PHINode>(&*i);
    if (!Phi)
      break;
    // For this phi node, get the incoming for BB.
    int Idx = Phi->getBasicBlockIndex(BB);
    IGC_ASSERT(Idx >= 0);
    Value *Incoming = Phi->getIncomingValue(Idx);
    // Iterate through BB's predecessors. For the first one, replace the
    // incoming block with the predecessor. For subsequent ones, we need
    // to add new phi incomings.
    auto pi = pred_begin(BB), pe = pred_end(BB);
    IGC_ASSERT(pi != pe);
    Phi->setIncomingBlock(Idx, *pi);
    for (++pi; pi != pe; ++pi)
      Phi->addIncoming(Incoming, *pi);
  }
}
1300 
/***********************************************************************
 * sinkAdd : sink add(s) in address calculation
 *
 * Enter:   IdxVal = the original index value
 *
 * Return:  the new calculation for the index value
 *
 * This detects the case when a variable index in a region or element access
 * is computed by one or more constant add/subs followed by some
 * mul/shl/truncs. It sinks the add/subs into a single add after the
 * mul/shl/truncs, so the add stands a chance of being baled in as a
 * constant offset in the region.
 *
 * If add sinking is successfully applied, it may leave now-unused
 * instructions behind, which need tidying by a later dead code removal
 * pass.
 */
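// For illustration (hypothetical IR), a chain such as
//   %a = add i32 %idx, 2
//   %b = shl i32 %a, 4
// is rewritten into
//   %b.clone = shl i32 %idx, 4
//   %addr_add = add i32 %b.clone, 32
// since (%idx + 2) << 4 == (%idx << 4) + 32.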
Value *genx::sinkAdd(Value *V) {
  Instruction *IdxVal = dyn_cast<Instruction>(V);
  if (!IdxVal)
    return V;
  // Collect the scale/trunc/add/sub/or instructions.
  int Offset = 0;
  SmallVector<Instruction *, 8> ScaleInsts;
  Instruction *Inst = IdxVal;
  int Scale = 1;
  bool NeedChange = false;
  for (;;) {
    if (isa<TruncInst>(Inst))
      ScaleInsts.push_back(Inst);
    else {
      if (!isa<BinaryOperator>(Inst))
        break;
      if (ConstantInt *CI = dyn_cast<ConstantInt>(Inst->getOperand(1))) {
        if (Inst->getOpcode() == Instruction::Mul) {
          Scale *= CI->getSExtValue();
          ScaleInsts.push_back(Inst);
        } else if (Inst->getOpcode() == Instruction::Shl) {
          Scale <<= CI->getSExtValue();
          ScaleInsts.push_back(Inst);
        } else if (Inst->getOpcode() == Instruction::Add) {
          Offset += CI->getSExtValue() * Scale;
          if (V != Inst)
            NeedChange = true;
        } else if (Inst->getOpcode() == Instruction::Sub) {
          Offset -= CI->getSExtValue() * Scale;
          if (IdxVal != Inst)
            NeedChange = true;
        } else if (Inst->getOpcode() == Instruction::Or) {
          // An Or whose operands share no common bits acts as an Add.
          if (!haveNoCommonBitsSet(Inst->getOperand(0),
                                   Inst->getOperand(1),
                                   Inst->getModule()->getDataLayout()))
            break;
          Offset += CI->getSExtValue() * Scale;
          if (V != Inst)
            NeedChange = true;
        } else
          break;
      } else
        break;
    }
    Inst = dyn_cast<Instruction>(Inst->getOperand(0));
    if (!Inst)
      return V;
  }
  if (!NeedChange)
    return V;
  // Clone the scale and trunc instructions, starting with the value that
  // was input to the add(s).
  for (SmallVectorImpl<Instruction *>::reverse_iterator i = ScaleInsts.rbegin(),
                                                        e = ScaleInsts.rend();
       i != e; ++i) {
    Instruction *Clone = (*i)->clone();
    Clone->insertBefore(IdxVal);
    Clone->setName((*i)->getName());
    Clone->setOperand(0, Inst);
    Inst = Clone;
  }
  // Create a new add instruction.
  Inst = BinaryOperator::Create(
      Instruction::Add, Inst,
      ConstantInt::get(Inst->getType(), (int64_t)Offset, true /*isSigned*/),
      Twine("addr_add"), IdxVal);
  Inst->setDebugLoc(IdxVal->getDebugLoc());
  return Inst;
}

/***********************************************************************
 * reorderBlocks : reorder blocks to increase fallthrough, and specifically
 *    to satisfy the requirements of SIMD control flow
 */
#define SUCCSZANY     (true)
#define SUCCHASINST   (succ->size() > 1)
#define SUCCNOINST    (succ->size() <= 1)
#define SUCCANYLOOP   (true)

#define PUSHSUCC(BLK, C1, C2) \
        for(succ_iterator succIter = succ_begin(BLK), succEnd = succ_end(BLK); \
          succIter!=succEnd; ++succIter) {                                   \
          llvm::BasicBlock *succ = *succIter;                                \
          if (!visitSet.count(succ) && C1 && C2) {                           \
            visitVec.push_back(succ);                                        \
            visitSet.insert(succ);                                           \
            break;                                                           \
          }                                                                  \
        }
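// PUSHSUCC(BLK, C1, C2) pushes onto the DFS stack the first not-yet-visited
// successor of BLK satisfying both conditions, if any. Callers push with
// SUCCNOINST first (successors containing only a terminator), so near-empty
// blocks get laid out as fallthrough before the rest.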

static bool HasSimdGotoJoinInBlock(BasicBlock *BB)
{
  for (BasicBlock::iterator BBI = BB->begin(),
                            BBE = BB->end();
       BBI != BBE; ++BBI) {
    auto IID = GenXIntrinsic::getGenXIntrinsicID(&*BBI);
    if (IID == GenXIntrinsic::genx_simdcf_goto ||
        IID == GenXIntrinsic::genx_simdcf_join)
      return true;
  }
  return false;
}

void genx::LayoutBlocks(Function &func, LoopInfo &LI)
{
  std::vector<llvm::BasicBlock*> visitVec;
  std::set<llvm::BasicBlock*> visitSet;
  // Insertion position per loop header
  std::map<llvm::BasicBlock*, llvm::BasicBlock*> InsPos;

  llvm::BasicBlock* entry = &(func.getEntryBlock());
  visitVec.push_back(entry);
  visitSet.insert(entry);
  InsPos[entry] = entry;

  while (!visitVec.empty()) {
    llvm::BasicBlock* blk = visitVec.back();
    llvm::Loop *curLoop = LI.getLoopFor(blk);
    if (curLoop) {
      auto hd = curLoop->getHeader();
      if (blk == hd && InsPos.find(hd) == InsPos.end()) {
        InsPos[blk] = blk;
      }
    }
    // push: visit empty successors first
    PUSHSUCC(blk, SUCCANYLOOP, SUCCNOINST);
    if (blk != visitVec.back())
      continue;
    // push: then the remaining successors
    PUSHSUCC(blk, SUCCANYLOOP, SUCCHASINST);
    // pop: time to move the block to the right location
    if (blk == visitVec.back()) {
      visitVec.pop_back();
      if (curLoop) {
        auto hd = curLoop->getHeader();
        if (blk != hd) {
          // move the block to the beginning of the loop
          auto insp = InsPos[hd];
          IGC_ASSERT(insp);
          if (blk != insp) {
            blk->moveBefore(insp);
            InsPos[hd] = blk;
          }
        }
        else {
          // move the entire loop to the beginning of
          // the parent loop
          auto LoopStart = InsPos[hd];
          IGC_ASSERT(LoopStart);
          auto PaLoop = curLoop->getParentLoop();
          auto PaHd = PaLoop ? PaLoop->getHeader() : entry;
          auto insp = InsPos[PaHd];
          if (LoopStart == hd) {
            // single block loop
            hd->moveBefore(insp);
          }
          else {
            // the loop header is not moved yet, so it should be at the end;
            // use splice
            llvm::Function::BasicBlockListType& BBList = func.getBasicBlockList();
            BBList.splice(insp->getIterator(), BBList, LoopStart->getIterator(),
              hd->getIterator());
            hd->moveBefore(LoopStart);
          }
          InsPos[PaHd] = hd;
        }
      }
      else {
        auto insp = InsPos[entry];
        if (blk != insp) {
          blk->moveBefore(insp);
          InsPos[entry] = blk;
        }
      }
    }
  }

  // fix the loop-exit pattern, put break-blocks into the loop
  for (llvm::Function::iterator blkIter = func.begin(), blkEnd = func.end();
       blkIter != blkEnd; ++blkIter) {
    llvm::BasicBlock *blk = &(*blkIter);
    llvm::Loop *curLoop = LI.getLoopFor(blk);
    bool allPredLoopExit = true;
    unsigned numPreds = 0;
    llvm::SmallPtrSet<llvm::BasicBlock *, 4> predSet;
    for (pred_iterator predIter = pred_begin(blk), predEnd = pred_end(blk);
         predIter != predEnd; ++predIter) {
      llvm::BasicBlock *pred = *predIter;
      numPreds++;
      llvm::Loop *predLoop = LI.getLoopFor(pred);
      if (curLoop == predLoop) {
        llvm::BasicBlock *predPred = pred->getSinglePredecessor();
        if (predPred) {
          llvm::Loop *predPredLoop = LI.getLoopFor(predPred);
          if (predPredLoop != curLoop &&
              (!curLoop || curLoop->contains(predPredLoop))) {
            if (!HasSimdGotoJoinInBlock(pred)) {
              predSet.insert(pred);
            } else {
              allPredLoopExit = false;
              break;
            }
          }
        }
      } else if (!curLoop || curLoop->contains(predLoop))
        continue;
      else {
        allPredLoopExit = false;
        break;
      }
    }
    if (allPredLoopExit && numPreds > 1) {
      for (SmallPtrSet<BasicBlock *, 4>::iterator predIter = predSet.begin(),
                                                  predEnd = predSet.end();
           predIter != predEnd; ++predIter) {
        llvm::BasicBlock *pred = *predIter;
        llvm::BasicBlock *predPred = pred->getSinglePredecessor();
        IGC_ASSERT(predPred);
        pred->moveAfter(predPred);
      }
    }
  }
}

void genx::LayoutBlocks(Function &func)
{
  std::vector<llvm::BasicBlock*> visitVec;
  std::set<llvm::BasicBlock*> visitSet;
  // Reorder basic blocks to allow more fall-through
  llvm::BasicBlock* entry = &(func.getEntryBlock());
  visitVec.push_back(entry);
  visitSet.insert(entry);

  while (!visitVec.empty()) {
    llvm::BasicBlock* blk = visitVec.back();
    // push in the empty successor
    PUSHSUCC(blk, SUCCANYLOOP, SUCCNOINST);
    if (blk != visitVec.back())
      continue;
    // push in the other successor
    PUSHSUCC(blk, SUCCANYLOOP, SUCCHASINST);
    // pop
    if (blk == visitVec.back()) {
      visitVec.pop_back();
      if (blk != entry) {
        blk->moveBefore(entry);
        entry = blk;
      }
    }
  }
}

// Normalize g_load with bitcasts.
//
// When a single g_load is bitcast to different types, clone the g_load so
// each use type gets its own load.
bool genx::normalizeGloads(Instruction *Inst) {
  IGC_ASSERT(isa<LoadInst>(Inst));
  auto LI = cast<LoadInst>(Inst);
  if (getUnderlyingGlobalVariable(LI->getPointerOperand()) == nullptr)
    return false;

  // Collect all uses connected by bitcasts.
  std::set<BitCastInst *> Visited;
  // Uses of this load, grouped by use type.
  llvm::MapVector<Type *, std::vector<BitCastInst *>> Uses;
  // The working list.
  std::vector<BitCastInst *> Insts;

  for (auto UI : LI->users())
    if (auto BI = dyn_cast<BitCastInst>(UI))
      Insts.push_back(BI);

  while (!Insts.empty()) {
    BitCastInst *BCI = Insts.back();
    Insts.pop_back();
    if (Visited.count(BCI))
      continue;

    Uses[BCI->getType()].push_back(BCI);
    for (auto UI : BCI->users())
      if (auto BI = dyn_cast<BitCastInst>(UI))
        Insts.push_back(BI);
  }

  // If all bitcast uses share one type there is nothing to do; otherwise
  // clone loads that can fold the bitcasts.
  if (Uses.size() <= 1)
    return false;

  // %0 = load gv
  // %1 = bitcast %0 to t1
  // %2 = bitcast %1 to t2
  //
  // ==>
  // %0 = load gv
  // %0.1 = load gv
  // %1 = bitcast %0 to t1
  // %2 = bitcast %0.1 to t2
  Instruction *LInst = LI;
  for (auto I = Uses.begin(); I != Uses.end(); ++I) {
    Type *Ty = I->first;
    if (LInst == nullptr) {
      LInst = LI->clone();
      LInst->insertAfter(LI);
    }
    Instruction *NewCI = new BitCastInst(LInst, Ty, ".clone", LInst);
    NewCI->moveAfter(LInst);
    auto &BInsts = I->second;
    for (auto BI : BInsts)
      BI->replaceAllUsesWith(NewCI);
    LInst = nullptr;
  }
  return true;
}

// Fold a bitcast instruction into a load/store by changing the pointer type.
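// For illustration (hypothetical IR), the store path turns
//   %v = bitcast <4 x i32> %x to <2 x i64>
//   store <2 x i64> %v, <2 x i64>* @gv
// into
//   store <4 x i32> %x, <4 x i32>* bitcast (<2 x i64>* @gv to <4 x i32>*)
// so the bitcast of the value disappears; the load path is symmetric.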
Instruction *genx::foldBitCastInst(Instruction *Inst) {
  IGC_ASSERT(isa<LoadInst>(Inst) || isa<StoreInst>(Inst));
  auto LI = dyn_cast<LoadInst>(Inst);
  auto SI = dyn_cast<StoreInst>(Inst);

  Value *Ptr = LI ? LI->getPointerOperand() : SI->getPointerOperand();
  GlobalVariable *GV = getUnderlyingGlobalVariable(Ptr);
  if (!GV)
    return nullptr;

  if (SI) {
    Value *Val = SI->getValueOperand();
    if (auto CI = dyn_cast<BitCastInst>(Val)) {
      auto SrcTy = CI->getSrcTy();
      auto NewPtrTy = PointerType::get(SrcTy, SI->getPointerAddressSpace());
      auto NewPtr = ConstantExpr::getBitCast(GV, NewPtrTy);
      StoreInst *NewSI = new StoreInst(CI->getOperand(0), NewPtr,
                                       /*volatile*/ SI->isVolatile(), Inst);
      NewSI->takeName(SI);
      NewSI->setDebugLoc(Inst->getDebugLoc());
      Inst->eraseFromParent();
      return NewSI;
    }
  } else if (LI && LI->hasOneUse()) {
    if (auto CI = dyn_cast<BitCastInst>(LI->user_back())) {
      auto NewPtrTy = PointerType::get(CI->getType(), LI->getPointerAddressSpace());
      auto NewPtr = ConstantExpr::getBitCast(GV, NewPtrTy);
      auto NewLI = new LoadInst(NewPtrTy->getPointerElementType(), NewPtr, "",
                                /*volatile*/ LI->isVolatile(), Inst);
      NewLI->takeName(LI);
      NewLI->setDebugLoc(LI->getDebugLoc());
      CI->replaceAllUsesWith(NewLI);
      LI->replaceAllUsesWith(UndefValue::get(LI->getType()));
      LI->eraseFromParent();
      return NewLI;
    }
  }

  return nullptr;
}

const GlobalVariable *genx::getUnderlyingGlobalVariable(const Value *V) {
  while (auto *BI = dyn_cast<BitCastInst>(V))
    V = BI->getOperand(0);
  while (auto *CE = dyn_cast_or_null<ConstantExpr>(V)) {
    if (CE->getOpcode() == CastInst::BitCast)
      V = CE->getOperand(0);
    else
      break;
  }
  return dyn_cast_or_null<GlobalVariable>(V);
}

GlobalVariable *genx::getUnderlyingGlobalVariable(Value *V) {
  return const_cast<GlobalVariable *>(
      getUnderlyingGlobalVariable(const_cast<const Value *>(V)));
}

const GlobalVariable *genx::getUnderlyingGlobalVariable(const LoadInst *LI) {
  return getUnderlyingGlobalVariable(LI->getPointerOperand());
}

GlobalVariable *genx::getUnderlyingGlobalVariable(LoadInst *LI) {
  return getUnderlyingGlobalVariable(LI->getPointerOperand());
}

bool genx::isGlobalStore(Instruction *I) {
  IGC_ASSERT(I);
  if (auto *SI = dyn_cast<StoreInst>(I))
    return isGlobalStore(SI);
  return false;
}

bool genx::isGlobalStore(StoreInst *ST) {
  IGC_ASSERT(ST);
  return getUnderlyingGlobalVariable(ST->getPointerOperand()) != nullptr;
}

bool genx::isGlobalLoad(Instruction *I) {
  IGC_ASSERT(I);
  if (auto *LI = dyn_cast<LoadInst>(I))
    return isGlobalLoad(LI);
  return false;
}

bool genx::isGlobalLoad(LoadInst *LI) {
  IGC_ASSERT(LI);
  return getUnderlyingGlobalVariable(LI->getPointerOperand()) != nullptr;
}

bool genx::isLegalValueForGlobalStore(Value *V, Value *StorePtr) {
  // The value should be a wrregion...
  auto *Wrr = dyn_cast<CallInst>(V);
  if (!Wrr || !GenXIntrinsic::isWrRegion(Wrr))
    return false;

  // ...whose old value is obtained from a load of the same global as StorePtr.
  Value *OldVal =
      Wrr->getArgOperand(GenXIntrinsic::GenXRegion::OldValueOperandNum);
  auto *LI = dyn_cast<LoadInst>(OldVal);
  return LI && (getUnderlyingGlobalVariable(LI->getPointerOperand()) ==
                getUnderlyingGlobalVariable(StorePtr));
}

bool genx::isGlobalStoreLegal(StoreInst *ST) {
  IGC_ASSERT(isGlobalStore(ST));
  return isLegalValueForGlobalStore(ST->getValueOperand(),
                                    ST->getPointerOperand());
}

// The following bale will produce identity moves.
// %a0 = load m
// %b0 = load m
// bale {
//   %a1 = rrd %a0, R
//   %b1 = wrr %b0, %a1, R
//   store %b1, m
// }
//
bool genx::isIdentityBale(const Bale &B) {
  if (!B.endsWithGStore())
    return false;

  StoreInst *ST = cast<StoreInst>(B.getHead()->Inst);
  if (B.size() == 1) {
    // The value to be stored should be a load from the same global.
    auto LI = dyn_cast<LoadInst>(ST->getOperand(0));
    return LI && getUnderlyingGlobalVariable(LI->getOperand(0)) ==
                     getUnderlyingGlobalVariable(ST->getOperand(1));
  }
  if (B.size() != 3)
    return false;

  CallInst *B1 = dyn_cast<CallInst>(ST->getValueOperand());
  GlobalVariable *GV = getUnderlyingGlobalVariable(ST->getPointerOperand());
  if (!GenXIntrinsic::isWrRegion(B1) || !GV)
    return false;
  IGC_ASSERT(B1);
  auto B0 = dyn_cast<LoadInst>(B1->getArgOperand(0));
  if (!B0 || GV != getUnderlyingGlobalVariable(B0->getPointerOperand()))
    return false;

  CallInst *A1 = dyn_cast<CallInst>(B1->getArgOperand(1));
  if (!GenXIntrinsic::isRdRegion(A1))
    return false;
  IGC_ASSERT(A1);
  LoadInst *A0 = dyn_cast<LoadInst>(A1->getArgOperand(0));
  if (!A0 || GV != getUnderlyingGlobalVariable(A0->getPointerOperand()))
    return false;

  Region R1 = makeRegionFromBaleInfo(A1, BaleInfo());
  Region R2 = makeRegionFromBaleInfo(B1, BaleInfo());
  return R1 == R2;
}

// Check whether a region can be represented as a raw operand.
bool genx::isValueRegionOKForRaw(Value *V, bool IsWrite,
                                 const GenXSubtarget *ST) {
  IGC_ASSERT(V);
  switch (GenXIntrinsic::getGenXIntrinsicID(V)) {
  case GenXIntrinsic::genx_rdregioni:
  case GenXIntrinsic::genx_rdregionf:
    if (IsWrite)
      return false;
    break;
  case GenXIntrinsic::genx_wrregioni:
  case GenXIntrinsic::genx_wrregionf:
    if (!IsWrite)
      return false;
    break;
  default:
    return false;
  }
  Region R = makeRegionFromBaleInfo(cast<Instruction>(V), BaleInfo());
  return isRegionOKForRaw(R, ST);
}

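// A region is representable as a raw operand only if it is direct, starts
// on a GRF boundary, spans the whole value in a single row
// (Width == NumElements), and has unit stride.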
bool genx::isRegionOKForRaw(const genx::Region &R, const GenXSubtarget *ST) {
  unsigned GRFWidth = ST ? ST->getGRFByteSize() : 32;
  if (R.Indirect)
    return false;
  else if (R.Offset & (GRFWidth - 1)) // GRF boundary check
    return false;
  if (R.Width != R.NumElements)
    return false;
  if (R.Stride != 1)
    return false;
  return true;
}

bool genx::skipOptWithLargeBlock(FunctionGroup &FG) {
  for (auto fgi = FG.begin(), fge = FG.end(); fgi != fge; ++fgi) {
    auto F = *fgi;
    if (skipOptWithLargeBlock(*F))
      return true;
  }
  return false;
}

std::string genx::getInlineAsmCodes(const InlineAsm::ConstraintInfo &Info) {
  return Info.Codes.front();
}

bool genx::isInlineAsmMatchingInputConstraint(
    const InlineAsm::ConstraintInfo &Info) {
  return isdigit(Info.Codes.front()[0]);
}

genx::ConstraintType genx::getInlineAsmConstraintType(StringRef Codes) {
  return llvm::StringSwitch<genx::ConstraintType>(Codes)
      .Case("r", ConstraintType::Constraint_r)
      .Case("rw", ConstraintType::Constraint_rw)
      .Case("i", ConstraintType::Constraint_i)
      .Case("n", ConstraintType::Constraint_n)
      .Case("F", ConstraintType::Constraint_F)
      .Case("cr", ConstraintType::Constraint_cr)
      .Case("a", ConstraintType::Constraint_a)
      .Default(ConstraintType::Constraint_unknown);
}

unsigned
genx::getInlineAsmMatchedOperand(const InlineAsm::ConstraintInfo &Info) {
  IGC_ASSERT_MESSAGE(genx::isInlineAsmMatchingInputConstraint(Info),
    "Matching input expected");
  int OperandValue = std::stoi(Info.Codes.front());
  IGC_ASSERT(OperandValue >= 0);
  return OperandValue;
}

std::vector<GenXInlineAsmInfo> genx::getGenXInlineAsmInfo(MDNode *MD) {
  std::vector<GenXInlineAsmInfo> Result;
  for (auto &MDOp : MD->operands()) {
    auto EntryMD = dyn_cast<MDTuple>(MDOp);
    IGC_ASSERT_MESSAGE(EntryMD, "error setting metadata for inline asm");
    IGC_ASSERT_MESSAGE(EntryMD->getNumOperands() == 3,
      "error setting metadata for inline asm");
    ConstantAsMetadata *Op0 =
        dyn_cast<ConstantAsMetadata>(EntryMD->getOperand(0));
    ConstantAsMetadata *Op1 =
        dyn_cast<ConstantAsMetadata>(EntryMD->getOperand(1));
    ConstantAsMetadata *Op2 =
        dyn_cast<ConstantAsMetadata>(EntryMD->getOperand(2));
    IGC_ASSERT_MESSAGE(Op0, "error setting metadata for inline asm");
    IGC_ASSERT_MESSAGE(Op1, "error setting metadata for inline asm");
    IGC_ASSERT_MESSAGE(Op2, "error setting metadata for inline asm");
    auto CTy = static_cast<genx::ConstraintType>(
        cast<ConstantInt>(Op0->getValue())->getZExtValue());
    Result.emplace_back(CTy, cast<ConstantInt>(Op1->getValue())->getSExtValue(),
                        cast<ConstantInt>(Op2->getValue())->getZExtValue());
  }
  return Result;
}

std::vector<GenXInlineAsmInfo> genx::getGenXInlineAsmInfo(CallInst *CI) {
  IGC_ASSERT_MESSAGE(CI->isInlineAsm(), "Inline asm expected");
  MDNode *MD = CI->getMetadata(genx::MD_genx_inline_asm_info);
  // empty constraint info
  if (!MD) {
    auto *IA = cast<InlineAsm>(IGCLLVM::getCalledValue(CI));
    IGC_ASSERT_MESSAGE(IA->getConstraintString().empty(),
      "No info only for empty constraint string");
    (void)IA;
    return std::vector<GenXInlineAsmInfo>();
  }
  return genx::getGenXInlineAsmInfo(MD);
}

bool genx::hasConstraintOfType(
    const std::vector<GenXInlineAsmInfo> &ConstraintsInfo,
    genx::ConstraintType CTy) {
  return llvm::any_of(ConstraintsInfo, [&](const GenXInlineAsmInfo &Info) {
    return Info.getConstraintType() == CTy;
  });
}

unsigned genx::getInlineAsmNumOutputs(CallInst *CI) {
  IGC_ASSERT_MESSAGE(CI->isInlineAsm(), "Inline asm expected");
  unsigned NumOutputs;
  if (CI->getType()->isVoidTy())
    NumOutputs = 0;
  else if (auto ST = dyn_cast<StructType>(CI->getType()))
    NumOutputs = ST->getNumElements();
  else
    NumOutputs = 1;
  return NumOutputs;
}

/* for <1 x Ty> returns Ty
 * for Ty returns <1 x Ty>
 * other cases are unsupported
 */
Type *genx::getCorrespondingVectorOrScalar(Type *Ty) {
  if (Ty->isVectorTy()) {
    IGC_ASSERT_MESSAGE(
        cast<IGCLLVM::FixedVectorType>(Ty)->getNumElements() == 1,
        "wrong argument: scalar or degenerate vector is expected");
    return Ty->getScalarType();
  }
  return IGCLLVM::FixedVectorType::get(Ty, 1);
}

// Documentation is at the main template function.
CastInst *genx::scalarizeOrVectorizeIfNeeded(Instruction *Inst, Type *RefType) {
  return scalarizeOrVectorizeIfNeeded(Inst, &RefType, std::next(&RefType));
}

// Documentation is at the main template function.
CastInst *genx::scalarizeOrVectorizeIfNeeded(Instruction *Inst,
  Instruction *InstToReplace) {
  return scalarizeOrVectorizeIfNeeded(Inst, &InstToReplace, std::next(&InstToReplace));
}

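// getFunctionPointerFunc looks through chains of cast instructions and of
// cast/extractelement constant expressions, as well as splat constant
// vectors, to find the underlying llvm::Function, if any.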
Function *genx::getFunctionPointerFunc(Value *V) {
  Instruction *I = nullptr;
  for (; (I = dyn_cast<CastInst>(V)); V = I->getOperand(0))
    ;
  ConstantExpr *CE = nullptr;
  for (; (CE = dyn_cast<ConstantExpr>(V)) &&
         (CE->getOpcode() == Instruction::ExtractElement || CE->isCast());
       V = CE->getOperand(0))
    ;
  if (auto *F = dyn_cast<Function>(V))
    return F;
  if (auto *CV = dyn_cast<ConstantVector>(V); CV && CV->getSplatValue())
    return getFunctionPointerFunc(CV->getSplatValue());
  return nullptr;
}

bool genx::isFuncPointerVec(Value *V) {
  bool Res = false;
  if (V->getType()->isVectorTy() && isa<ConstantExpr>(V) &&
      cast<ConstantExpr>(V)->getOpcode() == Instruction::BitCast)
    Res = isFuncPointerVec(cast<ConstantExpr>(V)->getOperand(0));
  else if (ConstantVector *Vec = dyn_cast<ConstantVector>(V))
    Res = std::all_of(Vec->op_begin(), Vec->op_end(), [](Value *V) {
      return getFunctionPointerFunc(V) != nullptr;
    });
  return Res;
}

unsigned genx::getLogAlignment(VISA_Align Align, unsigned GRFWidth) {
  switch (Align) {
  case ALIGN_BYTE:
    return Log2_32(ByteBytes);
  case ALIGN_WORD:
    return Log2_32(WordBytes);
  case ALIGN_DWORD:
    return Log2_32(DWordBytes);
  case ALIGN_QWORD:
    return Log2_32(QWordBytes);
  case ALIGN_OWORD:
    return Log2_32(OWordBytes);
  case ALIGN_GRF:
    return Log2_32(GRFWidth);
  case ALIGN_2_GRF:
    return Log2_32(GRFWidth) + 1;
  default:
    report_fatal_error("Unknown alignment");
  }
}

VISA_Align genx::getVISA_Align(unsigned LogAlignment, unsigned GRFWidth) {
  if (LogAlignment == Log2_32(ByteBytes))
    return ALIGN_BYTE;
  else if (LogAlignment == Log2_32(WordBytes))
    return ALIGN_WORD;
  else if (LogAlignment == Log2_32(DWordBytes))
    return ALIGN_DWORD;
  else if (LogAlignment == Log2_32(QWordBytes))
    return ALIGN_QWORD;
  else if (LogAlignment == Log2_32(OWordBytes))
    return ALIGN_OWORD;
  else if (LogAlignment == Log2_32(GRFWidth))
    return ALIGN_GRF;
  else if (LogAlignment == Log2_32(GRFWidth) + 1)
    return ALIGN_2_GRF;
  else
    report_fatal_error("Unknown log alignment");
}

unsigned genx::ceilLogAlignment(unsigned LogAlignment, unsigned GRFWidth) {
  if (LogAlignment <= Log2_32(ByteBytes))
    return Log2_32(ByteBytes);
  else if (LogAlignment <= Log2_32(WordBytes))
    return Log2_32(WordBytes);
  else if (LogAlignment <= Log2_32(DWordBytes))
    return Log2_32(DWordBytes);
  else if (LogAlignment <= Log2_32(QWordBytes))
    return Log2_32(QWordBytes);
  else if (LogAlignment <= Log2_32(OWordBytes))
    return Log2_32(OWordBytes);
  else if (LogAlignment <= Log2_32(GRFWidth))
    return Log2_32(GRFWidth);
  else if (LogAlignment <= Log2_32(GRFWidth) + 1)
    return Log2_32(GRFWidth) + 1;
  else
    report_fatal_error("Unknown log alignment");
}

bool genx::isWrPredRegionLegalSetP(const CallInst &WrPredRegion) {
  IGC_ASSERT_MESSAGE(GenXIntrinsic::getGenXIntrinsicID(&WrPredRegion) == GenXIntrinsic::genx_wrpredregion,
    "wrong argument: wrpredregion intrinsic was expected");
  auto &NewValue = *WrPredRegion.getOperand(WrPredRegionOperand::NewValue);
  auto ExecSize =
      NewValue.getType()->isVectorTy()
          ? cast<IGCLLVM::FixedVectorType>(NewValue.getType())->getNumElements()
          : 1;
  auto Offset =
      cast<ConstantInt>(WrPredRegion.getOperand(WrPredRegionOperand::Offset))
          ->getZExtValue();
  // An ExecSize of exactly 32 is still legal, so the bound check must be
  // strict; otherwise the ExecSize == 32 case below would be unreachable.
  if (ExecSize > 32 || !isPowerOf2_64(ExecSize))
    return false;
  if (ExecSize == 32)
    return Offset == 0;
  return Offset == 0 || Offset == 16;
}

CallInst *genx::checkFunctionCall(Value *V, Function *F) {
  if (!V || !F)
    return nullptr;
  auto *CI = dyn_cast<CallInst>(V);
  if (CI && CI->getCalledFunction() == F)
    return CI;
  return nullptr;
}

unsigned genx::getNumGRFsPerIndirectForRegion(const genx::Region &R,
                                              const GenXSubtarget *ST,
                                              bool Allow2D) {
  IGC_ASSERT_MESSAGE(R.Indirect, "Indirect region expected");
  IGC_ASSERT(ST);
  if (ST->hasIndirectGRFCrossing() &&
      // SKL+. See if we can allow GRF crossing.
      (Allow2D || !R.is2D())) {
    return 2;
  }
  return 1;
}

bool genx::isRealGlobalVariable(const GlobalVariable &GV) {
  if (GV.hasAttribute("genx_volatile"))
    return false;
  if (GV.hasAttribute(genx::VariableMD::VCPredefinedVariable))
    return false;
  bool IsIndexedString =
      std::any_of(GV.user_begin(), GV.user_end(), [](const User *Usr) {
        return vc::isLegalPrintFormatIndexGEP(*Usr);
      });
  if (IsIndexedString) {
    IGC_ASSERT_MESSAGE(std::all_of(GV.user_begin(), GV.user_end(),
                                   [](const User *Usr) {
                                     return vc::isLegalPrintFormatIndexGEP(
                                         *Usr);
                                   }),
                       "when global is an indexed string, its users can only "
                       "be print format index GEPs");
    return false;
  }
  return true;
}

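// For illustration (hypothetical layout): for struct { i32, i64 } with
// natural alignment the layout is { 0: i32, 4: pad, 8: i64 }, so element 0
// has padded size 8 (offset(1) - offset(0)) and the last element has size
// 8 (getSizeInBytes() - offset(1)).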
std::size_t genx::getStructElementPaddedSize(unsigned ElemIdx,
                                             unsigned NumOperands,
                                             const StructLayout &Layout) {
  IGC_ASSERT_MESSAGE(ElemIdx < NumOperands,
                     "wrong argument: invalid index into a struct");
  if (ElemIdx == NumOperands - 1)
    return Layout.getSizeInBytes() - Layout.getElementOffset(ElemIdx);
  return Layout.getElementOffset(ElemIdx + 1) -
         Layout.getElementOffset(ElemIdx);
}

// splitStructPhi : split a phi node with struct type by splitting into
//                  struct elements
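// For illustration (hypothetical IR), a struct phi such as
//   %s = phi { i32, float } [ %a, %BB1 ], [ %b, %BB2 ]
// becomes one phi per element, recombined with insertvalue:
//   %s.element0 = phi i32 [ %a.e0, %BB1 ], [ %b.e0, %BB2 ]
//   %s.element1 = phi float [ %a.e1, %BB1 ], [ %b.e1, %BB2 ]
// where %a.e0 etc. are extractvalues emitted in the incoming blocks.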
bool genx::splitStructPhi(PHINode *Phi) {
  StructType *Ty = cast<StructType>(Phi->getType());
  // Find where we need to insert the combine instructions.
  Instruction *CombineInsertBefore = Phi->getParent()->getFirstNonPHI();
  // Now split the phi.
  Value *Combined = UndefValue::get(Ty);
  // For each struct element...
  for (unsigned Idx = 0, e = Ty->getNumElements(); Idx != e; ++Idx) {
    Type *ElTy = Ty->getTypeAtIndex(Idx);
    // Create the new phi node.
    PHINode *NewPhi =
        PHINode::Create(ElTy, Phi->getNumIncomingValues(),
                        Phi->getName() + ".element" + Twine(Idx), Phi);
    NewPhi->setDebugLoc(Phi->getDebugLoc());
    // Combine the new phi.
    Instruction *Combine = InsertValueInst::Create(
        Combined, NewPhi, Idx, NewPhi->getName(), CombineInsertBefore);
    Combine->setDebugLoc(Phi->getDebugLoc());
    Combined = Combine;
    // For each incoming...
    for (unsigned In = 0, InEnd = Phi->getNumIncomingValues(); In != InEnd;
         ++In) {
      // Create an extractvalue to get the individual element value.
      // This needs to go before the terminator of the incoming block.
      BasicBlock *IncomingBB = Phi->getIncomingBlock(In);
      Value *Incoming = Phi->getIncomingValue(In);
      Instruction *Extract = ExtractValueInst::Create(
          Incoming, Idx, Phi->getName() + ".element" + Twine(Idx),
          IncomingBB->getTerminator());
      Extract->setDebugLoc(Phi->getDebugLoc());
      // Add as an incoming of the new phi node.
      NewPhi->addIncoming(Extract, IncomingBB);
    }
  }
  Phi->replaceAllUsesWith(Combined);
  Phi->eraseFromParent();
  return true;
}

bool genx::splitStructPhis(Function *F) {
  bool Modified = false;
  for (Function::iterator fi = F->begin(), fe = F->end(); fi != fe; ++fi) {
    BasicBlock *BB = &*fi;
    for (BasicBlock::iterator bi = BB->begin();;) {
      PHINode *Phi = dyn_cast<PHINode>(&*bi);
      if (!Phi)
        break;
      ++bi; // increment here as splitStructPhi removes the old phi node
      if (isa<StructType>(Phi->getType()))
        Modified |= splitStructPhi(Phi);
    }
  }
  return Modified;
}

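// hasMemoryDeps : check whether a killing store to Addr may appear between
// the two loads L1 and L2 (both vloads or both global loads); conservatively
// returns true when their relative placement cannot be analyzed.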
bool genx::hasMemoryDeps(Instruction *L1, Instruction *L2, Value *Addr,
                         DominatorTree *DT) {
  // Return false (no dependency) for anything other than a pair of vloads
  // or a pair of global loads.
  if (!(GenXIntrinsic::isVLoad(L1) && GenXIntrinsic::isVLoad(L2)) &&
      !(isGlobalLoad(L1) && isGlobalLoad(L2)))
    return false;

  auto isKill = [=](Instruction &I) {
    Instruction *Inst = &I;
    if ((GenXIntrinsic::isVStore(Inst) || genx::isGlobalStore(Inst)) &&
        (Addr == Inst->getOperand(1) ||
         Addr == getUnderlyingGlobalVariable(Inst->getOperand(1))))
      return true;
    // Not a killing store.
    return false;
  };

  // Global loads from the same block.
  if (L1->getParent() == L2->getParent()) {
    BasicBlock::iterator I = L1->getParent()->begin();
    for (; &*I != L1 && &*I != L2; ++I)
      /*empty*/;
    IGC_ASSERT(&*I == L1 || &*I == L2);
    auto IEnd = (&*I == L1) ? L2->getIterator() : L1->getIterator();
    return std::any_of(I->getIterator(), IEnd, isKill);
  }

  // Global loads are from different blocks.
  //
  //       BB1 (L1)
  //      /   \
  //   BB3    BB2 (L2)
  //     \     /
  //       BB4
  //
  auto BB1 = L1->getParent();
  auto BB2 = L2->getParent();
  if (!DT->properlyDominates(BB1, BB2)) {
    std::swap(BB1, BB2);
    std::swap(L1, L2);
  }
  if (DT->properlyDominates(BB1, BB2)) {
    // As BB1 dominates BB2, we can recursively check BB2's predecessors until
    // reaching BB1.
    //
    // Check BB1 and BB2 themselves.
    if (std::any_of(BB2->begin(), L2->getIterator(), isKill))
      return true;
    if (std::any_of(L1->getIterator(), BB1->end(), isKill))
      return true;
    std::set<BasicBlock *> Visited{BB1, BB2};
    std::vector<BasicBlock *> BBs;
    std::copy_if(pred_begin(BB2), pred_end(BB2), std::back_inserter(BBs),
                 [Visited](BasicBlock *BB) { return !Visited.count(BB); });

    // This visits the subgraph dominated by BB1, originated from BB2.
    while (!BBs.empty()) {
      BasicBlock *BB = BBs.back();
      BBs.pop_back();
      Visited.insert(BB);

      // Check if there is any killing store in this block.
      if (std::any_of(BB->begin(), BB->end(), isKill))
        return true;

      // Populate not-yet-visited predecessors.
      std::copy_if(pred_begin(BB), pred_end(BB), std::back_inserter(BBs),
                   [Visited](BasicBlock *BB) { return !Visited.count(BB); });
    }

    // No memory dependencies found.
    return false;
  }

  return true;
}

bool genx::isRdRFromGlobalLoad(Value *V) {
  if (!GenXIntrinsic::isRdRegion(V))
    return false;
  auto *RdR = cast<CallInst>(V);
  auto *I = dyn_cast<Instruction>(
      RdR->getArgOperand(GenXIntrinsic::GenXRegion::OldValueOperandNum));
  return I && isGlobalLoad(I);
}

bool genx::isWrRToGlobalLoad(Value *V) {
  if (!GenXIntrinsic::isWrRegion(V))
    return false;
  auto *WrR = cast<CallInst>(V);
  auto *I = dyn_cast<Instruction>(
      WrR->getArgOperand(GenXIntrinsic::GenXRegion::OldValueOperandNum));
  return I && isGlobalLoad(I);
}
