/*========================== begin_copyright_notice ============================

Copyright (C) 2017-2021 Intel Corporation

SPDX-License-Identifier: MIT

============================= end_copyright_notice ===========================*/

//
// Utility functions for the GenX backend.
//
//===----------------------------------------------------------------------===//
#include "GenXUtil.h"
#include "FunctionGroup.h"
#include "GenX.h"
#include "GenXIntrinsics.h"

#include "vc/GenXOpts/Utils/InternalMetadata.h"
#include "vc/Utils/GenX/Printf.h"
#include "vc/Utils/General/Types.h"

#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/PostOrderIterator.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/GenXIntrinsics/GenXIntrinsics.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/Metadata.h"
#include "llvm/IR/Module.h"

#include "llvmWrapper/IR/DerivedTypes.h"
#include "llvmWrapper/IR/InstrTypes.h"
#include "llvmWrapper/IR/Instructions.h"

#include "Probe/Assertion.h"
#include <cstddef>
#include <iterator>

using namespace llvm;
using namespace genx;

namespace {
struct InstScanner {
  Instruction *Original;
  Instruction *Current;
  InstScanner(Instruction *Inst) : Original(Inst), Current(Inst) {}
};

} // namespace

/***********************************************************************
 * createConvert : create a genx_convert intrinsic call
 *
 * Enter:   In = value to convert
 *          Name = name to give convert instruction
 *          InsertBefore = instruction to insert before else 0
 *          M = Module (can be 0 as long as InsertBefore is not 0)
 */
CallInst *genx::createConvert(Value *In, const Twine &Name,
                              Instruction *InsertBefore, Module *M)
{
  if (!M)
    M = InsertBefore->getParent()->getParent()->getParent();
  Function *Decl = GenXIntrinsic::getGenXDeclaration(M, GenXIntrinsic::genx_convert,
                                                     In->getType());
  return CallInst::Create(Decl, In, Name, InsertBefore);
}
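
// Illustrative usage (a sketch; value names are hypothetical): given an
// instruction Inst and an i32 value Val visible at that point,
//   CallInst *Conv = genx::createConvert(Val, "conv", Inst, nullptr);
// inserts a call of roughly the form
//   %conv = call i32 @llvm.genx.convert.i32(i32 %val)
// immediately before Inst, with the module inferred from InsertBefore
// (the exact overload suffix depends on the value's type).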

/***********************************************************************
 * createConvertAddr : create a genx_convert_addr intrinsic call
 *
 * Enter:   In = value to convert
 *          Offset = constant offset
 *          Name = name to give convert instruction
 *          InsertBefore = instruction to insert before else 0
 *          M = Module (can be 0 as long as InsertBefore is not 0)
 */
CallInst *genx::createConvertAddr(Value *In, int Offset, const Twine &Name,
                                  Instruction *InsertBefore, Module *M)
{
  if (!M)
    M = InsertBefore->getParent()->getParent()->getParent();
  auto OffsetVal = ConstantInt::get(In->getType()->getScalarType(), Offset);
  Function *Decl = GenXIntrinsic::getGenXDeclaration(M, GenXIntrinsic::genx_convert_addr,
                                                     In->getType());
  Value *Args[] = { In, OffsetVal };
  return CallInst::Create(Decl, Args, Name, InsertBefore);
}

/***********************************************************************
 * createAddAddr : create a genx_add_addr intrinsic call
 *
 * InsertBefore can be 0 so the new instruction is not inserted anywhere,
 * but in that case M must be non-0 and set to the Module.
 */
CallInst *genx::createAddAddr(Value *Lhs, Value *Rhs, const Twine &Name,
                              Instruction *InsertBefore, Module *M)
{
  if (!M)
    M = InsertBefore->getParent()->getParent()->getParent();
  Value *Args[] = {Lhs, Rhs};
  Type *Tys[] = {Rhs->getType(), Lhs->getType()};
  Function *Decl = GenXIntrinsic::getGenXDeclaration(M, GenXIntrinsic::genx_add_addr, Tys);
  return CallInst::Create(Decl, Args, Name, InsertBefore);
}

/***********************************************************************
 * createUnifiedRet : create a dummy instruction that produces a dummy
 * unified return value.
 *
 * %Name.unifiedret = call Ty @llvm.ssa_copy(Ty undef)
 */
CallInst *genx::createUnifiedRet(Type *Ty, const Twine &Name, Module *M) {
  IGC_ASSERT_MESSAGE(Ty, "wrong argument");
  IGC_ASSERT_MESSAGE(M, "wrong argument");
  auto G = Intrinsic::getDeclaration(M, Intrinsic::ssa_copy, Ty);
  return CallInst::Create(G, UndefValue::get(Ty), Name + ".unifiedret",
                          static_cast<Instruction *>(nullptr));
}

/***********************************************************************
 * getPredicateConstantAsInt : get an i1 or vXi1 constant's value as a single
 * integer
 *
 * Elements of constant \p C are encoded as the least significant bits of the
 * result. In the scalar case only the LSB of the result is set to the
 * corresponding value.
 */
unsigned genx::getPredicateConstantAsInt(const Constant *C) {
  IGC_ASSERT_MESSAGE(C->getType()->isIntOrIntVectorTy(1),
    "wrong argument: constant of i1 or Nxi1 type was expected");
  if (auto CI = dyn_cast<ConstantInt>(C))
    return CI->getZExtValue(); // scalar
  unsigned Bits = 0;
  unsigned NumElements =
      cast<IGCLLVM::FixedVectorType>(C->getType())->getNumElements();
  IGC_ASSERT_MESSAGE(NumElements <= sizeof(Bits) * CHAR_BIT,
    "vector has too many elements; it won't fit into Bits");
  for (unsigned i = 0; i != NumElements; ++i) {
    auto El = C->getAggregateElement(i);
    if (!isa<UndefValue>(El))
      Bits |= (cast<ConstantInt>(El)->getZExtValue() & 1) << i;
  }
  return Bits;
}
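
// Worked example (illustrative): for a <4 x i1> constant
// <i1 1, i1 0, i1 undef, i1 1>, element i is encoded in bit i of the result,
// so the returned value is 0b1001 = 9; the undef lane contributes a zero bit.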

/***********************************************************************
 * getConstantSubvector : get a contiguous region from a vector constant
 */
Constant *genx::getConstantSubvector(const Constant *V, unsigned StartIdx,
                                     unsigned Size) {
  Type *ElTy = cast<VectorType>(V->getType())->getElementType();
  Type *RegionTy = IGCLLVM::FixedVectorType::get(ElTy, Size);
  Constant *SubVec = nullptr;
  if (isa<UndefValue>(V))
    SubVec = UndefValue::get(RegionTy);
  else if (isa<ConstantAggregateZero>(V))
    SubVec = ConstantAggregateZero::get(RegionTy);
  else {
    SmallVector<Constant *, 32> Val;
    for (unsigned i = 0; i != Size; ++i)
      Val.push_back(V->getAggregateElement(i + StartIdx));
    SubVec = ConstantVector::get(Val);
  }
  return SubVec;
}

/***********************************************************************
 * concatConstants : concatenate two possibly vector constants, giving a
 * vector constant
 */
Constant *genx::concatConstants(Constant *C1, Constant *C2)
{
  IGC_ASSERT(C1->getType()->getScalarType() == C2->getType()->getScalarType());
  Constant *CC[] = { C1, C2 };
  SmallVector<Constant *, 8> Vec;
  bool AllUndef = true;
  for (unsigned Idx = 0; Idx != 2; ++Idx) {
    Constant *C = CC[Idx];
    if (auto *VT = dyn_cast<IGCLLVM::FixedVectorType>(C->getType())) {
      for (unsigned i = 0, e = VT->getNumElements(); i != e; ++i) {
        Constant *El = C->getAggregateElement(i);
        Vec.push_back(El);
        AllUndef &= isa<UndefValue>(El);
      }
    } else {
      Vec.push_back(C);
      AllUndef &= isa<UndefValue>(C);
    }
  }
  auto Res = ConstantVector::get(Vec);
  if (AllUndef)
    Res = UndefValue::get(Res->getType());
  return Res;
}
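
// Illustrative example: concatenating <2 x i16> <1, 2> with the scalar
// i16 3 yields <3 x i16> <1, 2, 3>; concatenating two fully-undef constants
// folds the result to undef of the combined vector type.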

/***********************************************************************
 * findClosestCommonDominator : find closest common dominator of some
 *    instructions
 *
 * Enter:   DT = dominator tree
 *          Insts = the instructions
 *
 * Return:  The one instruction that dominates all the others, if any.
 *          Otherwise the terminator of the closest common dominating basic
 *          block.
 */
Instruction *genx::findClosestCommonDominator(DominatorTree *DT,
                                              ArrayRef<Instruction *> Insts)
{
  IGC_ASSERT(!Insts.empty());
  SmallVector<InstScanner, 8> InstScanners;
  // Find the closest common dominating basic block.
  Instruction *Inst0 = Insts[0];
  BasicBlock *NCD = Inst0->getParent();
  InstScanners.push_back(InstScanner(Inst0));
  for (unsigned ii = 1, ie = Insts.size(); ii != ie; ++ii) {
    Instruction *Inst = Insts[ii];
    if (Inst->getParent() != NCD) {
      auto NewNCD = DT->findNearestCommonDominator(NCD, Inst->getParent());
      if (NewNCD != NCD)
        InstScanners.clear();
      NCD = NewNCD;
    }
    if (NCD == Inst->getParent())
      InstScanners.push_back(Inst);
  }
  // Now we have NCD = the closest common dominating basic block, and
  // InstScanners populated with the instructions from Insts that are
  // in that block.
  if (InstScanners.empty()) {
    // No instructions in that block. Return the block's terminator.
    return NCD->getTerminator();
  }
  if (InstScanners.size() == 1) {
    // Only one instruction in that block. Return it.
    return InstScanners[0].Original;
  }
  // Create a set of the original instructions.
  std::set<Instruction *> OrigInsts;
  for (auto i = InstScanners.begin(), e = InstScanners.end(); i != e; ++i)
    OrigInsts.insert(i->Original);
  // Scan back one instruction at a time for each scanner. If a scanner reaches
  // another original instruction, the scanner can be removed, and when we are
  // left with one scanner, that must be the earliest of the original
  // instructions. If a scanner reaches the start of the basic block, that was
  // the earliest of the original instructions.
  //
  // In the worst case, this algorithm could scan all the instructions in a
  // basic block, but it is designed to be better than that in the common case
  // that the original instructions are close to each other.
  for (;;) {
    for (auto i = InstScanners.begin(), e = InstScanners.end(); i != e; ++i) {
      if (i->Current == &i->Current->getParent()->front())
        return i->Original; // reached start of basic block
      i->Current = i->Current->getPrevNode();
      if (OrigInsts.find(i->Current) != OrigInsts.end()) {
        // Scanned back to another instruction in our original set. Remove
        // this scanner.
        *i = InstScanners.back();
        InstScanners.pop_back();
        if (InstScanners.size() == 1)
          return InstScanners[0].Original; // only one scanner left
        break; // restart loop so as not to confuse the iterator
      }
    }
  }
}

/***********************************************************************
 * getTwoAddressOperandNum : get operand number of two address operand
 *
 * If an intrinsic has a "two address operand", then that operand must be
 * in the same register as the result. This function returns the operand number
 * of the two address operand if any, or None if not.
 */
llvm::Optional<unsigned> genx::getTwoAddressOperandNum(CallInst *CI)
{
  auto IntrinsicID = GenXIntrinsic::getAnyIntrinsicID(CI);
  if (IntrinsicID == GenXIntrinsic::not_any_intrinsic)
    return None; // not intrinsic
  // wr(pred(pred))region has operand 0 as two address operand
  if (GenXIntrinsic::isWrRegion(IntrinsicID) ||
      IntrinsicID == GenXIntrinsic::genx_wrpredregion ||
      IntrinsicID == GenXIntrinsic::genx_wrpredpredregion)
    return GenXIntrinsic::GenXRegion::OldValueOperandNum;
  if (CI->getType()->isVoidTy())
    return None; // no return value
  GenXIntrinsicInfo II(IntrinsicID);
  unsigned Num = CI->getNumArgOperands();
  if (!Num)
    return None; // no args
  --Num; // Num = last arg number, could be two address operand
  if (isa<UndefValue>(CI->getOperand(Num)))
    return None; // operand is undef, must be RAW_NULLALLOWED
  if (II.getArgInfo(Num).getCategory() != GenXIntrinsicInfo::TWOADDR)
    return None; // not two addr operand
  if (CI->use_empty() && II.getRetInfo().rawNullAllowed())
    return None; // unused result will be V0
  return Num; // it is two addr
}
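
// Illustrative example: for a wrregion call of roughly the form
//   %w = call <8 x i32> @llvm.genx.wrregioni.*(<8 x i32> %old, <4 x i32> %new, ...)
// the result %w must be allocated to the same register as %old, so this
// function returns OldValueOperandNum (operand 0).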

/***********************************************************************
 * isNot : test whether an instruction is a "not" instruction (an xor with
 *    constant all ones)
 */
bool genx::isNot(Instruction *Inst)
{
  if (Inst->getOpcode() == Instruction::Xor)
    if (auto C = dyn_cast<Constant>(Inst->getOperand(1)))
      if (C->isAllOnesValue())
        return true;
  return false;
}

/***********************************************************************
 * isPredNot : test whether an instruction is a "not" instruction (an xor
 *    with constant all ones) with predicate (i1 or vector of i1) type
 */
bool genx::isPredNot(Instruction *Inst)
{
  if (Inst->getOpcode() == Instruction::Xor)
    if (auto C = dyn_cast<Constant>(Inst->getOperand(1)))
      if (C->isAllOnesValue() && C->getType()->getScalarType()->isIntegerTy(1))
        return true;
  return false;
}

/***********************************************************************
 * isIntNot : test whether an instruction is a "not" instruction (an xor
 *    with constant all ones) with non-predicate type
 */
bool genx::isIntNot(Instruction *Inst)
{
  if (Inst->getOpcode() == Instruction::Xor)
    if (auto C = dyn_cast<Constant>(Inst->getOperand(1)))
      if (C->isAllOnesValue() && !C->getType()->getScalarType()->isIntegerTy(1))
        return true;
  return false;
}

/***********************************************************************
 * invertCondition : invert the given predicate value, possibly reusing
 *    an existing copy.
 */
Value *genx::invertCondition(Value *Condition)
{
  IGC_ASSERT_MESSAGE(Condition->getType()->getScalarType()->isIntegerTy(1),
    "Condition is not of predicate type");
  // First: Check if it's a constant.
  if (Constant *C = dyn_cast<Constant>(Condition))
    return ConstantExpr::getNot(C);

  // Second: If the condition is already inverted, return the original value.
  Instruction *Inst = dyn_cast<Instruction>(Condition);
  if (Inst && isPredNot(Inst))
    return Inst->getOperand(0);

  // Last option: Create a new instruction.
  auto *Inverted =
      BinaryOperator::CreateNot(Condition, Condition->getName() + ".inv");
  if (Inst && !isa<PHINode>(Inst))
    Inverted->insertAfter(Inst);
  else {
    BasicBlock *Parent = nullptr;
    if (Inst)
      Parent = Inst->getParent();
    else if (Argument *Arg = dyn_cast<Argument>(Condition))
      Parent = &Arg->getParent()->getEntryBlock();
    IGC_ASSERT_MESSAGE(Parent, "Unsupported condition to invert");
    Inverted->insertBefore(&*Parent->getFirstInsertionPt());
  }
  return Inverted;
}

/***********************************************************************
 * isNoopCast : test if a cast operation doesn't modify the bitwise
 *    representation of the value (in other words, it can be copy-coalesced).
 * NOTE: LLVM has the CastInst::isNoopCast method, but it conservatively treats
 * AddrSpaceCast as a modifying operation; this function can be more
 * aggressive, relying on DataLayout information.
 */
bool genx::isNoopCast(const CastInst *CI) {
  const DataLayout &DL = CI->getModule()->getDataLayout();
  switch (CI->getOpcode()) {
  case Instruction::BitCast:
    return true;
  case Instruction::PtrToInt:
  case Instruction::IntToPtr:
  case Instruction::AddrSpaceCast:
    return vc::getTypeSize(CI->getDestTy(), &DL) ==
           vc::getTypeSize(CI->getSrcTy(), &DL);
  default:
    return false;
  }
}
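
// Illustrative example: with a DataLayout where pointers and i64 are both
// 64 bits wide, "ptrtoint i8* %p to i64" preserves the bitwise value and is
// a no-op cast by this definition, whereas "ptrtoint i8* %p to i32" changes
// the size and is not.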

/***********************************************************************
 * ShuffleVectorAnalyzer::getAsSlice : see if the shufflevector is a slice on
 *    operand 0, and if so return the start index, or -1 if it is not a slice
 */
int ShuffleVectorAnalyzer::getAsSlice()
{
  unsigned WholeWidth =
      cast<IGCLLVM::FixedVectorType>(SI->getOperand(0)->getType())
          ->getNumElements();
  Constant *Selector = IGCLLVM::getShuffleMaskForBitcode(SI);
  unsigned Width =
      cast<IGCLLVM::FixedVectorType>(SI->getType())->getNumElements();
  auto *Aggr = Selector->getAggregateElement(0u);
  if (isa<UndefValue>(Aggr))
    return -1; // operand 0 is undef value
  unsigned StartIdx = cast<ConstantInt>(Aggr)->getZExtValue();
  if (StartIdx >= WholeWidth)
    return -1; // start index beyond operand 0
  unsigned SliceWidth;
  for (SliceWidth = 1; SliceWidth != Width; ++SliceWidth) {
    auto CI = dyn_cast<ConstantInt>(Selector->getAggregateElement(SliceWidth));
    if (!CI)
      break;
    if (CI->getZExtValue() != StartIdx + SliceWidth)
      return -1; // not slice
  }
  return StartIdx;
}
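
// Illustrative example: a mask of <2, 3, 4, 5> reading a wide enough
// operand 0 is the contiguous slice starting at index 2, so getAsSlice()
// returns 2; a mask of <2, 4, 6, 8> is not contiguous and yields -1.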

/***********************************************************************
 * ShuffleVectorAnalyzer::isReplicatedSlice : check if the shufflevector
 *    is a replicated slice on operand 0.
 */
bool ShuffleVectorAnalyzer::isReplicatedSlice() const {
  const auto MaskVals = SI->getShuffleMask();
  auto Begin = MaskVals.begin();
  auto End = MaskVals.end();

  // Check for undefs.
  if (std::find(Begin, End, -1) != End)
    return false;

  if (MaskVals.size() == 1)
    return true;

  // Slice should not touch second operand.
  auto MaxIndex = static_cast<size_t>(MaskVals.back());
  if (MaxIndex >= cast<IGCLLVM::FixedVectorType>(SI->getOperand(0)->getType())
                      ->getNumElements())
    return false;

  // Find first non-one difference.
  auto SliceEnd =
      std::adjacent_find(Begin, End,
                         [](int Prev, int Next) { return Next - Prev != 1; });
  // If not found, then it is simple slice.
  if (SliceEnd == End)
    return true;

  // Compare slice with parts of sequence to prove that it is periodic.
  ++SliceEnd;
  unsigned SliceSize = std::distance(Begin, SliceEnd);
  // Slice should be replicated.
  if (MaskVals.size() % SliceSize != 0)
    return false;

  for (auto It = SliceEnd; It != End; std::advance(It, SliceSize))
    if (!std::equal(Begin, SliceEnd, It))
      return false;

  return true;
}
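
// Illustrative example: the mask <4, 5, 4, 5, 4, 5> is the slice {4, 5}
// replicated three times, so isReplicatedSlice() returns true; the mask
// <4, 5, 6, 4> fails the replication check (its size is not a multiple of
// the detected slice size) and returns false.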

Value *genx::getMaskOperand(const Instruction *Inst) {
  IGC_ASSERT(Inst);

  // Return null for any instruction other than GenX intrinsics.
  auto *CI = dyn_cast<CallInst>(Inst);
  if (!CI || !GenXIntrinsic::isGenXIntrinsic(CI))
    return nullptr;

  auto MaskOpIt = llvm::find_if(CI->operands(), [](const Use &U) {
    Value *Operand = U.get();
    if (auto *VT = dyn_cast<VectorType>(Operand->getType()))
      return VT->getElementType()->isIntegerTy(1);
    return false;
  });

  // No mask among operands.
  if (MaskOpIt == CI->op_end())
    return nullptr;

  return *MaskOpIt;
}

// Based on the value of a shufflevector mask element, determines which of the
// 2 operands it points into. The operand is returned.
static Value *getOperandByMaskValue(const ShuffleVectorInst &SI,
                                    int MaskValue) {
  IGC_ASSERT_MESSAGE(MaskValue >= 0, "invalid index");
  int FirstOpSize = cast<IGCLLVM::FixedVectorType>(SI.getOperand(0)->getType())
                        ->getNumElements();
  if (MaskValue < FirstOpSize)
    return SI.getOperand(0);
  else {
    int SecondOpSize =
        cast<IGCLLVM::FixedVectorType>(SI.getOperand(1)->getType())
            ->getNumElements();
    IGC_ASSERT_MESSAGE(MaskValue < FirstOpSize + SecondOpSize, "invalid index");
    return SI.getOperand(1);
  }
}

// Safe advance:
// if adding \p N would violate the bound, \p Last is written to \p It.
template <typename Iter> void advanceSafe(Iter &It, Iter Last, int N) {
  if (N > std::distance(It, Last)) {
    It = Last;
    return;
  }
  std::advance(It, N);
}

// Returns the operand and its 1-element region referenced by the \p MaskVal
// element of the shufflevector mask.
static ShuffleVectorAnalyzer::OperandRegionInfo
matchOneElemRegion(const ShuffleVectorInst &SI, int MaskVal) {
  ShuffleVectorAnalyzer::OperandRegionInfo Init;
  Init.Op = getOperandByMaskValue(SI, MaskVal);
  Init.R = Region(Init.Op);
  Init.R.NumElements = Init.R.Width = 1;
  if (Init.Op == SI.getOperand(0))
    Init.R.Offset = MaskVal * Init.R.ElementBytes;
  else {
    auto FirstOpSize =
        cast<IGCLLVM::FixedVectorType>(SI.getOperand(0)->getType())
            ->getNumElements();
    Init.R.Offset = (MaskVal - FirstOpSize) * Init.R.ElementBytes;
  }
  return Init;
}

class MaskIndex {
  int Idx;
  static constexpr const int Undef = -1;
  static constexpr const int AnotherOp = -2;

public:
  explicit MaskIndex(int InitIdx = 0) : Idx(InitIdx) {
    IGC_ASSERT_MESSAGE(Idx >= 0, "Defined index must not be negative");
  }

  static MaskIndex getUndef() {
    MaskIndex Ret;
    Ret.Idx = Undef;
    return Ret;
  }

  static MaskIndex getAnotherOp() {
    MaskIndex Ret;
    Ret.Idx = AnotherOp;
    return Ret;
  }
  bool isUndef() const { return Idx == Undef; }
  bool isAnotherOp() const { return Idx == AnotherOp; }
  bool isDefined() const { return Idx >= 0; }

  int get() const {
    IGC_ASSERT_MESSAGE(Idx >= 0, "Can't call get() on invalid index");
    return Idx;
  }

  int operator-(MaskIndex const &rhs) const {
    IGC_ASSERT_MESSAGE(isDefined(), "All operand indices must be valid");
    IGC_ASSERT_MESSAGE(rhs.isDefined(), "All operand indices must be valid");
    return Idx - rhs.Idx;
  }
};

// Takes shufflevector mask indexes from [\p FirstIt, \p LastIt),
// converts them to the indices of \p Operand of the \p SI instruction
// and writes them to \p OutIt. The value type of OutIt is MaskIndex.
template <typename ForwardIter, typename OutputIter>
void makeSVIIndexesOperandIndexes(const ShuffleVectorInst &SI,
                                  const Value &Operand, ForwardIter FirstIt,
                                  ForwardIter LastIt, OutputIter OutIt) {
  int FirstOpSize = cast<IGCLLVM::FixedVectorType>(SI.getOperand(0)->getType())
                        ->getNumElements();
  if (&Operand == SI.getOperand(0)) {
    std::transform(FirstIt, LastIt, OutIt, [FirstOpSize](int MaskVal) {
      if (MaskVal >= FirstOpSize)
        return MaskIndex::getAnotherOp();
      return MaskVal < 0 ? MaskIndex::getUndef() : MaskIndex{MaskVal};
    });
    return;
  }
  IGC_ASSERT_MESSAGE(&Operand == SI.getOperand(1),
    "wrong argument: a shufflevector operand was expected");
  std::transform(FirstIt, LastIt, OutIt, [FirstOpSize](int MaskVal) {
    if (MaskVal < 0)
      return MaskIndex::getUndef();
    return MaskVal >= FirstOpSize ? MaskIndex{MaskVal - FirstOpSize}
                                  : MaskIndex::getAnotherOp();
  });
}

// Calculates the horizontal stride for a region by scanning mask indices in
// the range [\p FirstIt, \p LastIt).
//
// Arguments:
//    [\p FirstIt, \p LastIt) is the range of MaskIndex. There must not be
//    any AnotherOp indices in the range.
//    \p FirstIt must point to a defined index.
// Return value:
//    std::pair whose first element is an iterator to the next defined element
//    (or std::next(FirstIt) if there is none) and whose second element is the
//    estimated stride if it is positive and integral, an empty value
//    otherwise.
template <typename ForwardIter>
std::pair<ForwardIter, llvm::Optional<int>>
estimateHorizontalStride(ForwardIter FirstIt, ForwardIter LastIt) {

  IGC_ASSERT_MESSAGE(FirstIt != LastIt, "the range must contain at least 1 element");
  IGC_ASSERT_MESSAGE(std::none_of(FirstIt, LastIt, [](MaskIndex Idx) { return Idx.isAnotherOp(); }),
    "There must not be any AnotherOp indices in the range");
  IGC_ASSERT_MESSAGE(FirstIt->isDefined(),
    "first element in range must be a valid index");
  auto NextDefined =
      std::find_if(std::next(FirstIt), LastIt,
                   [](MaskIndex Elem) { return Elem.isDefined(); });

  if (NextDefined == LastIt)
    return {std::next(FirstIt), llvm::Optional<int>{}};

  int TotalStride = *NextDefined - *FirstIt;
  int TotalWidth = std::distance(FirstIt, NextDefined);

  if (TotalStride < 0 || (TotalStride % TotalWidth != 0 && TotalStride != 0))
    return {NextDefined, llvm::Optional<int>{}};

  return {NextDefined, TotalStride / TotalWidth};
}
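
// Worked example (illustrative): for the index sequence {2, undef, undef, 8}
// the next defined element after 2 is 8, at distance 3, giving a total
// stride of 6; 6 divides evenly by 3, so the estimated horizontal stride
// is 2.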

// Matches a "vector" region (with vstride == 0) pattern in the
// [\p FirstIt, \p LastIt) indexes.
// Uses info in \p FirstElemRegion, adds the deduced Width, Stride and
// new NumElements to \p FirstElemRegion and returns the resulting region.
//
// Arguments:
//    [\p FirstIt, \p LastIt) is the range of MaskIndex. There must not be
//    any AnotherOp indices in the range.
//    FirstIt and std::prev(LastIt) must point to defined indices.
//    \p FirstElemRegion describes a one element region with only one index
//    *FirstIt.
//    \p BoundIndex is the maximum possible index of the input vector + 1
//    (BoundIndex == InputVector.length)
template <typename ForwardIter>
Region matchVectorRegionByIndexes(Region FirstElemRegion, ForwardIter FirstIt,
                                  ForwardIter LastIt, int BoundIndex) {
  IGC_ASSERT_MESSAGE(FirstIt != LastIt, "the range must contain at least 1 element");
  IGC_ASSERT_MESSAGE(std::none_of(FirstIt, LastIt, [](MaskIndex Idx) { return Idx.isAnotherOp(); }),
    "There must not be any AnotherOp indices in the range.");
  IGC_ASSERT_MESSAGE(FirstIt->isDefined(),
    "expected FirstIt and --LastIt point to valid indices");
  IGC_ASSERT_MESSAGE(std::prev(LastIt)->isDefined(),
    "expected FirstIt and --LastIt point to valid indices");

  if (std::distance(FirstIt, LastIt) == 1)
    return FirstElemRegion;

  llvm::Optional<int> RefStride;
  ForwardIter NewRowIt;
  std::tie(NewRowIt, RefStride) = estimateHorizontalStride(FirstIt, LastIt);

  if (!RefStride)
    return FirstElemRegion;

  llvm::Optional<int> Stride = RefStride;
  while (Stride == RefStride)
    std::tie(NewRowIt, Stride) = estimateHorizontalStride(NewRowIt, LastIt);

  auto TotalStride = std::distance(FirstIt, std::prev(NewRowIt)) * *RefStride;
  auto Overstep = TotalStride + FirstIt->get() - BoundIndex + 1;
  if (Overstep > 0)
    NewRowIt = std::prev(NewRowIt, llvm::divideCeil(Overstep, *RefStride));

  int Width = std::distance(FirstIt, NewRowIt);
  IGC_ASSERT_MESSAGE(Width > 0, "should be at least 1 according to algorithm");
  if (Width == 1)
    // Stride plays no role when the Width is 1. Zeroing it also prevents
    // writing too big a value into the region.
    RefStride = 0;
  FirstElemRegion.Stride = *RefStride;
  FirstElemRegion.Width = Width;
  FirstElemRegion.NumElements = Width;
  return FirstElemRegion;
}

// Calculates the vertical stride for a region by scanning mask indices in
// the range [\p FirstIt, \p LastIt).
//
// Arguments:
//    \p FirstRowRegion describes a "vector" region (with vstride == 0),
//    which is formed by the first 'FirstRowRegion.NumElements' elements
//    of the range.
//    \p FirstIt points to the first element of the vector of indices.
//    \p ReferenceIt points to some valid reference element in that vector.
//    It must be out of range of the first 'FirstRowRegion.NumElements'
//    elements.
//    The first 'FirstRowRegion.NumElements' elements in the range must be
//    defined indices.
// Return value:
//    Value of the estimated vertical stride if it is positive and integral,
//    an empty value otherwise.
template <typename ForwardIter>
llvm::Optional<int> estimateVerticalStride(Region FirstRowRegion,
                                           ForwardIter FirstIt,
                                           ForwardIter ReferenceIt) {

  IGC_ASSERT_MESSAGE(std::distance(FirstIt, ReferenceIt) >= static_cast<std::ptrdiff_t>(FirstRowRegion.Width),
    "Reference element must not be part of first row");
  IGC_ASSERT_MESSAGE(std::all_of(FirstIt, std::next(FirstIt, FirstRowRegion.Width), [](MaskIndex Elem) { return Elem.isDefined(); }),
    "First row must contain only valid indices");
  IGC_ASSERT_MESSAGE(ReferenceIt->isDefined(), "Reference index must be valid");

  int Width = FirstRowRegion.Width;

  int TotalDistance = std::distance(FirstIt, ReferenceIt);
  int VStridesToDef = TotalDistance / Width;
  int HStridesToDef = TotalDistance % Width;
  int TotalVerticalStride = *ReferenceIt - *std::next(FirstIt, HStridesToDef);
  if (TotalVerticalStride < 0 || TotalVerticalStride % VStridesToDef != 0)
    return llvm::Optional<int>{};

  return llvm::Optional<int>{TotalVerticalStride / VStridesToDef};
}

// Matches a "matrix" region (vstride may not equal 0) pattern in the
// [\p FirstIt, \p LastIt) indexes.
// Uses info in \p FirstRowRegion, adds the deduced VStride and new
// NumElements to \p FirstRowRegion and returns the resulting region.
//
// Arguments:
//    [\p FirstIt, \p LastIt) is the range of MaskIndex. Note that this
//    pass may change the contents of this vector (replace undef indices
//    with defined ones), so it can affect further usage.
//    \p LastDefinedIt points to the last element in the vector with a
//    defined index.
//    \p FirstRowRegion describes a "vector" region (with vstride == 0),
//    which is formed by the first 'FirstRowRegion.NumElements' elements
//    of the range.
//    \p BoundIndex is the maximum possible index of the input vector + 1
//    (BoundIndex == InputVector.length)
template <typename ForwardIter>
Region matchMatrixRegionByIndexes(Region FirstRowRegion, ForwardIter FirstIt,
                                  ForwardIter LastIt, ForwardIter LastDefinedIt,
                                  int BoundIndex) {
  IGC_ASSERT_MESSAGE(FirstRowRegion.NumElements == FirstRowRegion.Width,
    "wrong argument: vector region (with no vstride) was expected");
  IGC_ASSERT_MESSAGE(FirstRowRegion.VStride == 0,
    "wrong argument: vector region (with no vstride) was expected");
  IGC_ASSERT_MESSAGE(FirstIt->isDefined(),
    "expected FirstIt and LastDefinedIt point to valid indices");
  IGC_ASSERT_MESSAGE(LastDefinedIt->isDefined(),
    "expected FirstIt and LastDefinedIt point to valid indices");
  IGC_ASSERT_MESSAGE(std::distance(FirstIt, LastIt) >= static_cast<std::ptrdiff_t>(FirstRowRegion.Width),
    "wrong argument: number of indexes must be at least equal to region width");

  auto FirstRowEndIt = std::next(FirstIt, FirstRowRegion.Width);
  if (FirstRowEndIt == LastIt)
    return FirstRowRegion;

  auto FirstDefined = std::find_if(FirstRowEndIt, LastIt, [](MaskIndex Idx) {
    return Idx.isDefined() || Idx.isAnotherOp();
  });
  if (FirstDefined == LastIt || FirstDefined->isAnotherOp())
    return FirstRowRegion;

  int Stride = FirstRowRegion.Stride;
  int Idx = FirstIt->get();
  std::generate(std::next(FirstIt), FirstRowEndIt,
                [Idx, Stride]() mutable { return MaskIndex{Idx += Stride}; });

  llvm::Optional<int> VStride =
      estimateVerticalStride(FirstRowRegion, FirstIt, FirstDefined);
  if (!VStride)
    return FirstRowRegion;

  int VDistance = *VStride;
  int Width = FirstRowRegion.Width;
  int NumElements = FirstRowRegion.Width;
  int HighestFirstRowElement = std::prev(FirstRowEndIt)->get();

  for (auto It = FirstRowEndIt; It != LastIt; advanceSafe(It, LastIt, Width),
            NumElements += Width, VDistance += *VStride) {
    if (It > LastDefinedIt || std::distance(It, LastIt) < Width ||
        HighestFirstRowElement + VDistance >= BoundIndex ||
        !std::equal(FirstIt, FirstRowEndIt, It,
                    [VDistance](MaskIndex Reference, MaskIndex Current) {
                      return !Current.isAnotherOp() &&
                             (Current.isUndef() ||
                              Current.get() - Reference.get() == VDistance);
                    }))
      break;
  }

  if (NumElements == Width)
    // VStride plays no role when the Width is equal to NumElements.
    // Zeroing it also prevents writing too big a value into the region.
    VStride = 0;
  FirstRowRegion.VStride = *VStride;
  FirstRowRegion.NumElements = NumElements;
  return FirstRowRegion;
}

// Analyzes the shufflevector mask starting from its \p StartIdx element.
// Finds the longest prefix of the cut shufflevector mask that can be
// represented as a region of one operand of the instruction.
// Returns the operand and its region.
//
// For example:
// {0, 1, 3, 4, 25, 16 ...} -> first 4 elements form a region:
//                             <3;2,1> vstride=3, width=2, stride=1
ShuffleVectorAnalyzer::OperandRegionInfo
ShuffleVectorAnalyzer::getMaskRegionPrefix(int StartIdx) {
  IGC_ASSERT_MESSAGE(StartIdx >= 0, "Start index is out of bound");
  IGC_ASSERT_MESSAGE(StartIdx < static_cast<int>(SI->getShuffleMask().size()),
    "Start index is out of bound");

  auto MaskVals = SI->getShuffleMask();
  auto StartIt = std::next(MaskVals.begin(), StartIdx);
  OperandRegionInfo Res = matchOneElemRegion(*SI, *StartIt);

  if (StartIdx == MaskVals.size() - 1)
    return Res;

  std::vector<MaskIndex> SubMask;
  makeSVIIndexesOperandIndexes(*SI, *Res.Op, StartIt, MaskVals.end(),
                               std::back_inserter(SubMask));

  auto FirstAnotherOpElement =
      std::find_if(SubMask.begin(), SubMask.end(),
                   [](MaskIndex Elem) { return Elem.isAnotherOp(); });
  auto PastLastDefinedElement =
      std::find_if(std::reverse_iterator(FirstAnotherOpElement),
                   std::reverse_iterator(SubMask.begin()),
                   [](MaskIndex Elem) { return Elem.isDefined(); })
          .base();

  Res.R = matchVectorRegionByIndexes(
      std::move(Res.R), SubMask.begin(), PastLastDefinedElement,
      cast<IGCLLVM::FixedVectorType>(Res.Op->getType())->getNumElements());
  Res.R = matchMatrixRegionByIndexes(
      std::move(Res.R), SubMask.begin(), SubMask.end(),
      std::prev(PastLastDefinedElement),
      cast<IGCLLVM::FixedVectorType>(Res.Op->getType())->getNumElements());
  return Res;
}

/***********************************************************************
 * ShuffleVectorAnalyzer::getAsUnslice : see if the shufflevector is an
 *    unslice where the "old value" is operand 0 and operand 1 is another
 *    shufflevector and operand 0 of that is the "new value"
 *
 * Return:  start index, or -1 if it is not an unslice
 */
int ShuffleVectorAnalyzer::getAsUnslice()
{
  auto SI2 = dyn_cast<ShuffleVectorInst>(SI->getOperand(1));
  if (!SI2)
    return -1;
  Constant *MaskVec = IGCLLVM::getShuffleMaskForBitcode(SI);
  // Find prefix of undef or elements from operand 0.
  unsigned OldWidth =
      cast<IGCLLVM::FixedVectorType>(SI2->getType())->getNumElements();
  unsigned NewWidth =
      cast<IGCLLVM::FixedVectorType>(SI2->getOperand(0)->getType())
          ->getNumElements();
  unsigned Prefix = 0;
  for (;; ++Prefix) {
    if (Prefix == OldWidth - NewWidth)
      break;
    Constant *IdxC = MaskVec->getAggregateElement(Prefix);
    if (isa<UndefValue>(IdxC))
      continue;
    unsigned Idx = cast<ConstantInt>(IdxC)->getZExtValue();
    if (Idx == OldWidth)
      break; // found end of prefix
    if (Idx != Prefix)
      return -1; // not part of prefix
  }
  // Check that the whole of SI2 operand 0 follows
  for (unsigned i = 1; i != NewWidth; ++i) {
    Constant *IdxC = MaskVec->getAggregateElement(Prefix + i);
    if (isa<UndefValue>(IdxC))
      continue;
    if (cast<ConstantInt>(IdxC)->getZExtValue() != i + OldWidth)
      return -1; // not got whole of SI2 operand 0
  }
  // Check that the remainder is undef or elements from operand 0.
  for (unsigned i = Prefix + NewWidth; i != OldWidth; ++i) {
    Constant *IdxC = MaskVec->getAggregateElement(i);
    if (isa<UndefValue>(IdxC))
      continue;
    if (cast<ConstantInt>(IdxC)->getZExtValue() != i)
      return -1;
  }
  // Check that the first Prefix elements of SI2 come from its operand 1.
  Constant *MaskVec2 = IGCLLVM::getShuffleMaskForBitcode(SI2);
  for (unsigned i = 0; i != Prefix; ++i) {
    Constant *IdxC = MaskVec2->getAggregateElement(Prefix + i);
    if (isa<UndefValue>(IdxC))
      continue;
    if (cast<ConstantInt>(IdxC)->getZExtValue() != i)
      return -1;
  }
  // Success.
  return Prefix;
}

#if LLVM_VERSION_MAJOR <= 10
constexpr int UndefMaskElem = -1;
#endif

/***********************************************************************
 * extension of the ShuffleVectorInst::isZeroEltSplatMask method
 */
static int nEltSplatMask(ArrayRef<int> Mask) {
  int Elt = UndefMaskElem;
  for (int i = 0, NumElts = Mask.size(); i < NumElts; ++i) {
    if (Mask[i] == UndefMaskElem)
      continue;
    if ((Elt != UndefMaskElem) && (Mask[i] != Mask[Elt]))
      return UndefMaskElem;
    if ((Mask[i] != UndefMaskElem) && (Elt == UndefMaskElem))
      Elt = i;
  }
  return Elt;
}
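
// Illustrative example: the mask <undef, 3, undef, 3> splats one element,
// so nEltSplatMask returns 1, the position of the first defined index
// (getAsSplat below then reads the real index 3 from that position); the
// mask <3, 2, 3, 2> mixes two indices and yields UndefMaskElem.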

/***********************************************************************
 * ShuffleVectorAnalyzer::getAsSplat : if shufflevector is a splat, get the
 *    splatted input, with its vector index if the input is a vector
 */
ShuffleVectorAnalyzer::SplatInfo ShuffleVectorAnalyzer::getAsSplat()
{
  Value *InVec1 = SI->getOperand(0);
  Value *InVec2 = SI->getOperand(1);

  SmallVector<int, 16> MaskAsInts;
  SI->getShuffleMask(MaskAsInts);
  int ShuffleIdx = nEltSplatMask(MaskAsInts);
  if (ShuffleIdx == UndefMaskElem)
    return SplatInfo(nullptr, 0);

  // We have the position of the shuffle index as output; turn it into the
  // real index.
  ShuffleIdx = MaskAsInts[ShuffleIdx];

  // The mask is a splat. Work out which element of which input vector
  // it refers to.
  int InVec1NumElements =
      cast<IGCLLVM::FixedVectorType>(InVec1->getType())->getNumElements();
  if (ShuffleIdx >= InVec1NumElements) {
    ShuffleIdx -= InVec1NumElements;
    InVec1 = InVec2;
  }
  if (auto IE = dyn_cast<InsertElementInst>(InVec1)) {
    if (InVec1NumElements == 1 || isa<UndefValue>(IE->getOperand(0)))
      return SplatInfo(IE->getOperand(1), 0);
    // Even though this is a splat, the input vector has more than one
    // element. IRBuilder::CreateVectorSplat does this. See if the input
    // vector is the result of an insertelement at the right place, and
    // if so return that. Otherwise we end up allocating
    // an unnecessarily large register.
    if (auto ConstIdx = dyn_cast<ConstantInt>(IE->getOperand(2)))
      if (ConstIdx->getSExtValue() == ShuffleIdx)
        return SplatInfo(IE->getOperand(1), 0);
  }
  return SplatInfo(InVec1, ShuffleIdx);
}

Value *ShuffleVectorAnalyzer::serialize() {
  unsigned Cost0 = getSerializeCost(0);
  unsigned Cost1 = getSerializeCost(1);

  Value *Op0 = SI->getOperand(0);
  Value *Op1 = SI->getOperand(1);
  Value *V = Op0;
  bool UseOp0AsBase = Cost0 <= Cost1;
  if (!UseOp0AsBase)
    V = Op1;

  // Expand or shrink the initial value if the sizes mismatch.
  unsigned NElts =
      cast<IGCLLVM::FixedVectorType>(SI->getType())->getNumElements();
  unsigned M = cast<IGCLLVM::FixedVectorType>(V->getType())->getNumElements();
  bool SkipBase = true;
  if (M != NElts) {
    if (auto C = dyn_cast<Constant>(V)) {
      SmallVector<Constant *, 16> Vals;
      for (unsigned i = 0; i < NElts; ++i) {
        Type *Ty = cast<VectorType>(C->getType())->getElementType();
        Constant *Elt =
            (i < M) ? C->getAggregateElement(i) : UndefValue::get(Ty);
        Vals.push_back(Elt);
      }
      V = ConstantVector::get(Vals);
    } else {
      // Need to insert individual elements.
      V = UndefValue::get(SI->getType());
      SkipBase = false;
    }
  }

  IRBuilder<> Builder(SI);
  for (unsigned i = 0; i < NElts; ++i) {
    // Undef index returns -1.
    int idx = SI->getMaskValue(i);
    if (idx < 0)
      continue;
    if (SkipBase) {
      if (UseOp0AsBase && idx == i)
        continue;
      if (!UseOp0AsBase && idx == i + M)
        continue;
    }

    Value *Vi = nullptr;
    if (idx < (int)M)
      Vi = Builder.CreateExtractElement(Op0, idx, "");
    else
      Vi = Builder.CreateExtractElement(Op1, idx - M, "");
    if (!isa<UndefValue>(Vi))
      V = Builder.CreateInsertElement(V, Vi, i, "");
  }

  return V;
}

unsigned ShuffleVectorAnalyzer::getSerializeCost(unsigned i) {
  unsigned Cost = 0;
  Value *Op = SI->getOperand(i);
  if (!isa<Constant>(Op) && Op->getType() != SI->getType())
    Cost += cast<IGCLLVM::FixedVectorType>(Op->getType())->getNumElements();

  unsigned NElts =
      cast<IGCLLVM::FixedVectorType>(SI->getType())->getNumElements();
  for (unsigned j = 0; j < NElts; ++j) {
    // Undef index returns -1.
    int idx = SI->getMaskValue(j);
    if (idx < 0)
      continue;
    // Count the number of elements out of place.
    unsigned M =
        cast<IGCLLVM::FixedVectorType>(Op->getType())->getNumElements();
    if ((i == 0 && idx != j) || (i == 1 && idx != j + M))
      Cost++;
  }

  return Cost;
}

IVSplitter::IVSplitter(Instruction &Inst, const unsigned *BaseOpIdx)
    : Inst(Inst) {

  ETy = Inst.getType();
  if (BaseOpIdx)
    ETy = Inst.getOperand(*BaseOpIdx)->getType();

  Len = 1;
  if (auto *EVTy = dyn_cast<IGCLLVM::FixedVectorType>(ETy)) {
    Len = EVTy->getNumElements();
    ETy = EVTy->getElementType();
  }

  VI32Ty = IGCLLVM::FixedVectorType::get(ETy->getInt32Ty(Inst.getContext()),
                                         Len * 2);
}

IVSplitter::RegionTrait IVSplitter::describeSplit(RegionType RT, size_t ElNum) {
  RegionTrait Result;
  if (RT == RegionType::LoRegion || RT == RegionType::HiRegion) {
    // take every second element
    Result.ElStride = 2;
    Result.ElOffset = (RT == RegionType::LoRegion) ? 0 : 1;
  }
  else if (RT == RegionType::FirstHalf || RT == RegionType::SecondHalf) {
    // take every element, sequentially
    Result.ElStride = 1;
    Result.ElOffset = (RT == RegionType::FirstHalf) ? 0 : ElNum;
  } else {
    IGC_ASSERT_EXIT_MESSAGE(0, "incorrect region type");
  }
  return Result;
}
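
// Worked example (illustrative): splitting four 64-bit elements viewed as
// eight i32s, LoRegion selects i32 elements {0, 2, 4, 6} (stride 2,
// offset 0) and HiRegion selects {1, 3, 5, 7} (stride 2, offset 1), while
// FirstHalf selects {0, 1, 2, 3} and SecondHalf {4, 5, 6, 7} (stride 1,
// offsets 0 and ElNum = 4).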

Constant *
IVSplitter::splitConstantVector(const SmallVectorImpl<Constant *> &KV32,
                                RegionType RT) {
  IGC_ASSERT(KV32.size() % 2 == 0);
  SmallVector<Constant *, 16> Result;
  size_t ElNum = KV32.size() / 2;
  Result.reserve(ElNum);
  auto Split = describeSplit(RT, ElNum);
  for (size_t i = 0; i < ElNum; ++i) {
    size_t Offset = Split.ElOffset + i * Split.ElStride;
    IGC_ASSERT(Offset < KV32.size());
    Result.push_back(KV32[Offset]);
  }
  return ConstantVector::get(Result);
}

Region IVSplitter::createSplitRegion(Type *SrcTy, IVSplitter::RegionType RT) {
  IGC_ASSERT(SrcTy->isVectorTy());
  IGC_ASSERT(SrcTy->getScalarType()->isIntegerTy(32));
  IGC_ASSERT(cast<IGCLLVM::FixedVectorType>(SrcTy)->getNumElements() % 2 == 0);

  size_t Len = cast<IGCLLVM::FixedVectorType>(SrcTy)->getNumElements() / 2;

  auto Split = describeSplit(RT, Len);

  Region R(SrcTy);
  R.Width = Len;
  R.NumElements = Len;
  R.VStride = 0;
  R.Stride = Split.ElStride;
  // offset is encoded in bytes
  R.Offset = Split.ElOffset * 4;

  return R;
}

// Takes a 64-bit constant value (vector or scalar) and splits it into an
// equivalent vector of 32-bit constants (as if it were bitcast).
static void convertI64ToI32(Constant &K, SmallVectorImpl<Constant *> &K32) {
  auto I64To32 = [](const Constant &K) {
    // we expect only scalar types here
    IGC_ASSERT(!isa<VectorType>(K.getType()));
    IGC_ASSERT(K.getType()->isIntegerTy(64));
    auto *Ty32 = K.getType()->getInt32Ty(K.getContext());
    if (isa<UndefValue>(K)) {
      Constant *Undef = UndefValue::get(Ty32);
      return std::make_pair(Undef, Undef);
    }
    auto *KI = cast<ConstantInt>(&K);
    uint64_t Val64 = KI->getZExtValue();
    const auto UI32ValueMask = std::numeric_limits<uint32_t>::max();
    Constant *VLo =
        ConstantInt::get(Ty32, static_cast<uint32_t>(Val64 & UI32ValueMask));
    Constant *VHi = ConstantInt::get(Ty32, static_cast<uint32_t>(Val64 >> 32));
    return std::make_pair(VLo, VHi);
  };

  IGC_ASSERT(K32.empty());
  if (!isa<VectorType>(K.getType())) {
    auto V32 = I64To32(K);
    K32.push_back(V32.first);
    K32.push_back(V32.second);
    return;
  }
  unsigned ElNum =
      cast<IGCLLVM::FixedVectorType>(K.getType())->getNumElements();
  K32.reserve(2 * ElNum);
  for (unsigned i = 0; i < ElNum; ++i) {
    auto V32 = I64To32(*K.getAggregateElement(i));
    K32.push_back(V32.first);
    K32.push_back(V32.second);
  }
}
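
// Worked example (illustrative): the i64 constant 0x00000001FFFFFFFF is
// split into VLo = 0xFFFFFFFF and VHi = 0x1, so a <2 x i64> <C, undef>
// becomes the four i32 constants {0xFFFFFFFF, 0x1, undef, undef} in K32,
// low word first.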

std::pair<Value *, Value *>
IVSplitter::splitValue(Value &Val, RegionType RT1, const Twine &Name1,
                       RegionType RT2, const Twine &Name2, bool FoldConstants) {
  const auto &DL = Inst.getDebugLoc();
  auto BaseName = Inst.getName();

  IGC_ASSERT(Val.getType()->getScalarType()->isIntegerTy(64));

  if (FoldConstants && isa<Constant>(Val)) {
    SmallVector<Constant *, 32> KV32;
    convertI64ToI32(cast<Constant>(Val), KV32);
    Value *V1 = splitConstantVector(KV32, RT1);
    Value *V2 = splitConstantVector(KV32, RT2);
    return {V1, V2};
  }
  auto *ShreddedVal = new BitCastInst(&Val, VI32Ty, BaseName + ".iv32cast", &Inst);
  ShreddedVal->setDebugLoc(DL);

  auto R1 = createSplitRegion(VI32Ty, RT1);
  auto *V1 = R1.createRdRegion(ShreddedVal, BaseName + Name1, &Inst, DL);

  auto R2 = createSplitRegion(VI32Ty, RT2);
  auto *V2 = R2.createRdRegion(ShreddedVal, BaseName + Name2, &Inst, DL);
  return { V1, V2 };
}

IVSplitter::LoHiSplit IVSplitter::splitOperandLoHi(unsigned SourceIdx,
                                                   bool FoldConstants) {

  IGC_ASSERT(Inst.getNumOperands() > SourceIdx);
  return splitValueLoHi(*Inst.getOperand(SourceIdx), FoldConstants);
}
IVSplitter::HalfSplit IVSplitter::splitOperandHalf(unsigned SourceIdx,
                                                   bool FoldConstants) {

  IGC_ASSERT(Inst.getNumOperands() > SourceIdx);
  return splitValueHalf(*Inst.getOperand(SourceIdx), FoldConstants);
}

IVSplitter::LoHiSplit IVSplitter::splitValueLoHi(Value &V, bool FoldConstants) {
  auto Splitted = splitValue(V, RegionType::LoRegion, ".LoSplit",
                             RegionType::HiRegion, ".HiSplit", FoldConstants);
  return {Splitted.first, Splitted.second};
}
IVSplitter::HalfSplit IVSplitter::splitValueHalf(Value &V, bool FoldConstants) {
  auto Splitted =
      splitValue(V, RegionType::FirstHalf, ".FirstHalf", RegionType::SecondHalf,
                 ".SecondHalf", FoldConstants);
  return {Splitted.first, Splitted.second};
}

Value* IVSplitter::combineSplit(Value &V1, Value &V2, RegionType RT1,
                                RegionType RT2, const Twine& Name,
                                bool Scalarize) {
  const auto &DL = Inst.getDebugLoc();

  IGC_ASSERT(V1.getType() == V2.getType());
  IGC_ASSERT(V1.getType()->isVectorTy());
  IGC_ASSERT(cast<VectorType>(V1.getType())->getElementType()->isIntegerTy(32));

  // create the write-regions
  auto R1 = createSplitRegion(VI32Ty, RT1);
  auto *UndefV = UndefValue::get(VI32Ty);
  auto *W1 = R1.createWrRegion(UndefV, &V1, Name + "partial_join", &Inst, DL);

  auto R2 = createSplitRegion(VI32Ty, RT2);
  auto *W2 = R2.createWrRegion(W1, &V2, Name + "joined", &Inst, DL);

  auto *V64Ty =
      IGCLLVM::FixedVectorType::get(ETy->getInt64Ty(Inst.getContext()), Len);
  auto *Result = new BitCastInst(W2, V64Ty, Name, &Inst);
  Result->setDebugLoc(DL);

  if (Scalarize) {
    IGC_ASSERT(
        cast<IGCLLVM::FixedVectorType>(Result->getType())->getNumElements() ==
        1);
    Result = new BitCastInst(Result, ETy->getInt64Ty(Inst.getContext()),
                             Name + "recast", &Inst);
    Result->setDebugLoc(DL);
  }
  return Result;

}
Value *IVSplitter::combineLoHiSplit(const LoHiSplit &Split, const Twine &Name,
                                    bool Scalarize) {
  IGC_ASSERT(Split.Lo);
  IGC_ASSERT(Split.Hi);

  return combineSplit(*Split.Lo, *Split.Hi, RegionType::LoRegion,
                      RegionType::HiRegion, Name, Scalarize);
}

Value *IVSplitter::combineHalfSplit(const HalfSplit &Split, const Twine &Name,
                                    bool Scalarize) {
  IGC_ASSERT(Split.Left);
  IGC_ASSERT(Split.Right);

  return combineSplit(*Split.Left, *Split.Right, RegionType::FirstHalf,
                      RegionType::SecondHalf, Name, Scalarize);
}
/***********************************************************************
 * adjustPhiNodesForBlockRemoval : adjust phi nodes when removing a block
 *
 * Enter:   Succ = the successor block to adjust phi nodes in
 *          BB = the block being removed
 *
 * This modifies each phi node in Succ as follows: the incoming for BB is
 * replaced by an incoming for each of BB's predecessors.
 */
void genx::adjustPhiNodesForBlockRemoval(BasicBlock *Succ, BasicBlock *BB)
{
  for (auto i = Succ->begin(), e = Succ->end(); i != e; ++i) {
    auto Phi = dyn_cast<PHINode>(&*i);
    if (!Phi)
      break;
    // For this phi node, get the incoming for BB.
    int Idx = Phi->getBasicBlockIndex(BB);
    IGC_ASSERT(Idx >= 0);
    Value *Incoming = Phi->getIncomingValue(Idx);
    // Iterate through BB's predecessors. For the first one, replace the
    // incoming block with the predecessor. For subsequent ones, we need
    // to add new phi incomings.
    auto pi = pred_begin(BB), pe = pred_end(BB);
    IGC_ASSERT(pi != pe);
    Phi->setIncomingBlock(Idx, *pi);
    for (++pi; pi != pe; ++pi)
      Phi->addIncoming(Incoming, *pi);
  }
}
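
// Illustrative example: if BB has predecessors P1 and P2, a phi in Succ of
// the form
//   %phi = phi i32 [ %v, %BB ], [ %u, %other ]
// is rewritten so that the %BB incoming becomes [ %v, %P1 ] and a new
// incoming [ %v, %P2 ] is appended.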

/***********************************************************************
 * sinkAdd : sink add(s) in address calculation
 *
 * Enter:   IdxVal = the original index value
 *
 * Return:  the new calculation for the index value
 *
 * This detects the case when a variable index in a region or element access
 * is one or more constant add/subs then some mul/shl/truncs. It sinks
 * the add/subs into a single add after the mul/shl/truncs, so the add
 * stands a chance of being baled in as a constant offset in the region.
 *
 * If add sinking is successfully applied, it may leave now unused
 * instructions behind, which need tidying by a later dead code removal
 * pass.
 */
Value *genx::sinkAdd(Value *V) {
  Instruction *IdxVal = dyn_cast<Instruction>(V);
  if (!IdxVal)
    return V;
  // Collect the scale/trunc/add/sub/or instructions.
  int Offset = 0;
  SmallVector<Instruction *, 8> ScaleInsts;
  Instruction *Inst = IdxVal;
  int Scale = 1;
  bool NeedChange = false;
  for (;;) {
    if (isa<TruncInst>(Inst))
      ScaleInsts.push_back(Inst);
    else {
      if (!isa<BinaryOperator>(Inst))
        break;
      if (ConstantInt *CI = dyn_cast<ConstantInt>(Inst->getOperand(1))) {
        if (Inst->getOpcode() == Instruction::Mul) {
          Scale *= CI->getSExtValue();
          ScaleInsts.push_back(Inst);
        } else if (Inst->getOpcode() == Instruction::Shl) {
          Scale <<= CI->getSExtValue();
          ScaleInsts.push_back(Inst);
        } else if (Inst->getOpcode() == Instruction::Add) {
          Offset += CI->getSExtValue() * Scale;
          if (V != Inst)
            NeedChange = true;
        } else if (Inst->getOpcode() == Instruction::Sub) {
          Offset -= CI->getSExtValue() * Scale;
          if (IdxVal != Inst)
            NeedChange = true;
        } else if (Inst->getOpcode() == Instruction::Or) {
          if (!haveNoCommonBitsSet(Inst->getOperand(0),
                                   Inst->getOperand(1),
                                   Inst->getModule()->getDataLayout()))
            break;
          Offset += CI->getSExtValue() * Scale;
          if (V != Inst)
            NeedChange = true;
        } else
          break;
      } else
        break;
    }
    Inst = dyn_cast<Instruction>(Inst->getOperand(0));
    if (!Inst)
      return V;
  }
  if (!NeedChange)
    return V;
  // Clone the scale and trunc instructions, starting with the value that
  // was input to the add(s).
  for (SmallVectorImpl<Instruction *>::reverse_iterator i = ScaleInsts.rbegin(),
                                                        e = ScaleInsts.rend();
       i != e; ++i) {
    Instruction *Clone = (*i)->clone();
    Clone->insertBefore(IdxVal);
    Clone->setName((*i)->getName());
    Clone->setOperand(0, Inst);
    Inst = Clone;
  }
  // Create a new add instruction.
  Inst = BinaryOperator::Create(
      Instruction::Add, Inst,
      ConstantInt::get(Inst->getType(), (int64_t)Offset, true /*isSigned*/),
      Twine("addr_add"), IdxVal);
  Inst->setDebugLoc(IdxVal->getDebugLoc());
  return Inst;
}
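
// Worked example (illustrative, with %x some non-constant instruction):
//   %a = add i32 %x, 3
//   %i = shl i32 %a, 2
// is rewritten (the shl is cloned, the add sunk below it) into roughly
//   %i.clone = shl i32 %x, 2
//   %addr_add = add i32 %i.clone, 12
// so the constant 12 stands a chance of being baled in as a region offset.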
1386
1387 /***********************************************************************
1388 * reorderBlocks : reorder blocks to increase fallthrough, and specifically
1389 * to satisfy the requirements of SIMD control flow
1390 */
1391 #define SUCCSZANY (true)
1392 #define SUCCHASINST (succ->size() > 1)
1393 #define SUCCNOINST (succ->size() <= 1)
1394 #define SUCCANYLOOP (true)
1395
1396 #define PUSHSUCC(BLK, C1, C2) \
1397 for(succ_iterator succIter = succ_begin(BLK), succEnd = succ_end(BLK); \
1398 succIter!=succEnd; ++succIter) { \
1399 llvm::BasicBlock *succ = *succIter; \
1400 if (!visitSet.count(succ) && C1 && C2) { \
1401 visitVec.push_back(succ); \
1402 visitSet.insert(succ); \
1403 break; \
1404 } \
1405 }
1406
HasSimdGotoJoinInBlock(BasicBlock * BB)1407 static bool HasSimdGotoJoinInBlock(BasicBlock *BB)
1408 {
1409 for (BasicBlock::iterator BBI = BB->begin(),
1410 BBE = BB->end();
1411 BBI != BBE; ++BBI) {
1412 auto IID = GenXIntrinsic::getGenXIntrinsicID(&*BBI);
1413 if (IID == GenXIntrinsic::genx_simdcf_goto ||
1414 IID == GenXIntrinsic::genx_simdcf_join)
1415 return true;
1416 }
1417 return false;
1418 }

void genx::LayoutBlocks(Function &func, LoopInfo &LI)
{
  std::vector<llvm::BasicBlock*> visitVec;
  std::set<llvm::BasicBlock*> visitSet;
  // Insertion position per loop header
  std::map<llvm::BasicBlock*, llvm::BasicBlock*> InsPos;

  llvm::BasicBlock* entry = &(func.getEntryBlock());
  visitVec.push_back(entry);
  visitSet.insert(entry);
  InsPos[entry] = entry;

  while (!visitVec.empty()) {
    llvm::BasicBlock* blk = visitVec.back();
    llvm::Loop *curLoop = LI.getLoopFor(blk);
    if (curLoop) {
      auto hd = curLoop->getHeader();
      if (blk == hd && InsPos.find(hd) == InsPos.end()) {
        InsPos[blk] = blk;
      }
    }
    // push: DFS visit, empty successors first
    PUSHSUCC(blk, SUCCANYLOOP, SUCCNOINST);
    if (blk != visitVec.back())
      continue;
    // push: then the successors that contain real instructions
    PUSHSUCC(blk, SUCCANYLOOP, SUCCHASINST);
    // pop: time to move the block to the right location
    if (blk == visitVec.back()) {
      visitVec.pop_back();
      if (curLoop) {
        auto hd = curLoop->getHeader();
        if (blk != hd) {
          // move the block to the beginning of the loop
          auto insp = InsPos[hd];
          IGC_ASSERT(insp);
          if (blk != insp) {
            blk->moveBefore(insp);
            InsPos[hd] = blk;
          }
        } else {
          // move the entire loop to the beginning of
          // the parent loop
          auto LoopStart = InsPos[hd];
          IGC_ASSERT(LoopStart);
          auto PaLoop = curLoop->getParentLoop();
          auto PaHd = PaLoop ? PaLoop->getHeader() : entry;
          auto insp = InsPos[PaHd];
          if (LoopStart == hd) {
            // single-block loop
            hd->moveBefore(insp);
          } else {
            // the loop header has not been moved yet, so it is still at
            // the end; splice the rest of the loop, then move the header
            llvm::Function::BasicBlockListType &BBList =
                func.getBasicBlockList();
            BBList.splice(insp->getIterator(), BBList, LoopStart->getIterator(),
                          hd->getIterator());
            hd->moveBefore(LoopStart);
          }
          InsPos[PaHd] = hd;
        }
      } else {
        auto insp = InsPos[entry];
        if (blk != insp) {
          blk->moveBefore(insp);
          InsPos[entry] = blk;
        }
      }
    }
  }

  // fix the loop-exit pattern: put break-blocks into the loop
  for (llvm::Function::iterator blkIter = func.begin(), blkEnd = func.end();
       blkIter != blkEnd; ++blkIter) {
    llvm::BasicBlock *blk = &(*blkIter);
    llvm::Loop *curLoop = LI.getLoopFor(blk);
    bool allPredLoopExit = true;
    unsigned numPreds = 0;
    llvm::SmallPtrSet<llvm::BasicBlock *, 4> predSet;
    for (pred_iterator predIter = pred_begin(blk), predEnd = pred_end(blk);
         predIter != predEnd; ++predIter) {
      llvm::BasicBlock *pred = *predIter;
      numPreds++;
      llvm::Loop *predLoop = LI.getLoopFor(pred);
      if (curLoop == predLoop) {
        llvm::BasicBlock *predPred = pred->getSinglePredecessor();
        if (predPred) {
          llvm::Loop *predPredLoop = LI.getLoopFor(predPred);
          if (predPredLoop != curLoop &&
              (!curLoop || curLoop->contains(predPredLoop))) {
            if (!HasSimdGotoJoinInBlock(pred)) {
              predSet.insert(pred);
            } else {
              allPredLoopExit = false;
              break;
            }
          }
        }
      } else if (!curLoop || curLoop->contains(predLoop))
        continue;
      else {
        allPredLoopExit = false;
        break;
      }
    }
    if (allPredLoopExit && numPreds > 1) {
      for (SmallPtrSet<BasicBlock *, 4>::iterator predIter = predSet.begin(),
                                                  predEnd = predSet.end();
           predIter != predEnd; ++predIter) {
        llvm::BasicBlock *pred = *predIter;
        llvm::BasicBlock *predPred = pred->getSinglePredecessor();
        IGC_ASSERT(predPred);
        pred->moveAfter(predPred);
      }
    }
  }
}
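
// Illustrative sketch of the loop-exit fixup above: 'brk1' and 'brk2' are
// break-blocks in the enclosing loop, each with a single predecessor
// ('exit1'/'exit2') inside an inner loop, and 'merge' joins them:
//
//   exit1 (inner loop)   exit2 (inner loop)
//     |                    |
//    brk1                 brk2
//       \                 /
//            merge
//
// When every predecessor of 'merge' is such a break-block and none contains
// a SIMD CF goto/join, each brkN is moved to sit right after its exitN, so
// the break-blocks are laid out within the inner loop region.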

void genx::LayoutBlocks(Function &func)
{
  std::vector<llvm::BasicBlock*> visitVec;
  std::set<llvm::BasicBlock*> visitSet;
  // Reorder basic blocks to allow more fall-through
  llvm::BasicBlock* entry = &(func.getEntryBlock());
  visitVec.push_back(entry);
  visitSet.insert(entry);

  while (!visitVec.empty()) {
    llvm::BasicBlock* blk = visitVec.back();
    // push in the empty successor
    PUSHSUCC(blk, SUCCANYLOOP, SUCCNOINST);
    if (blk != visitVec.back())
      continue;
    // push in the other successor
    PUSHSUCC(blk, SUCCANYLOOP, SUCCHASINST);
    // pop
    if (blk == visitVec.back()) {
      visitVec.pop_back();
      if (blk != entry) {
        blk->moveBefore(entry);
        entry = blk;
      }
    }
  }
}

// normalize g_load with bitcasts.
//
// When a single g_load is bitcast to different types, clone the g_load.
bool genx::normalizeGloads(Instruction *Inst) {
  IGC_ASSERT(isa<LoadInst>(Inst));
  auto LI = cast<LoadInst>(Inst);
  if (getUnderlyingGlobalVariable(LI->getPointerOperand()) == nullptr)
    return false;

  // collect all uses connected by bitcasts.
  std::set<BitCastInst *> Visited;
  // Uses of this load grouped by the use type.
  llvm::MapVector<Type *, std::vector<BitCastInst *>> Uses;
  // The working list.
  std::vector<BitCastInst *> Insts;

  for (auto UI : LI->users())
    if (auto BI = dyn_cast<BitCastInst>(UI))
      Insts.push_back(BI);

  while (!Insts.empty()) {
    BitCastInst *BCI = Insts.back();
    Insts.pop_back();
    if (Visited.count(BCI))
      continue;
    Visited.insert(BCI);

    Uses[BCI->getType()].push_back(BCI);
    for (auto UI : BCI->users())
      if (auto BI = dyn_cast<BitCastInst>(UI))
        Insts.push_back(BI);
  }

  // If there is more than one distinct use type, clone loads so each load
  // feeds bitcasts of a single type, which can then be folded.
  if (Uses.size() <= 1)
    return false;

  // %0 = load gv
  // %1 = bitcast %0 to t1
  // %2 = bitcast %1 to t2
  //
  // ==>
  // %0 = load gv
  // %0.1 = load gv
  // %1 = bitcast %0 to t1
  // %2 = bitcast %0.1 to t2
  Instruction *LInst = LI;
  for (auto I = Uses.begin(); I != Uses.end(); ++I) {
    Type *Ty = I->first;
    if (LInst == nullptr) {
      LInst = LI->clone();
      LInst->insertAfter(LI);
    }
    Instruction *NewCI = new BitCastInst(LInst, Ty, ".clone", LInst);
    NewCI->moveAfter(LInst);
    auto &BInsts = I->second;
    for (auto BI : BInsts)
      BI->replaceAllUsesWith(NewCI);
    LInst = nullptr;
  }
  return true;
}

// fold a bitcast into a load/store of a global by changing the pointer type.
Instruction *genx::foldBitCastInst(Instruction *Inst) {
  IGC_ASSERT(isa<LoadInst>(Inst) || isa<StoreInst>(Inst));
  auto LI = dyn_cast<LoadInst>(Inst);
  auto SI = dyn_cast<StoreInst>(Inst);

  Value *Ptr = LI ? LI->getPointerOperand() : SI->getPointerOperand();
  GlobalVariable *GV = getUnderlyingGlobalVariable(Ptr);
  if (!GV)
    return nullptr;

  if (SI) {
    Value *Val = SI->getValueOperand();
    if (auto CI = dyn_cast<BitCastInst>(Val)) {
      auto SrcTy = CI->getSrcTy();
      auto NewPtrTy = PointerType::get(SrcTy, SI->getPointerAddressSpace());
      auto NewPtr = ConstantExpr::getBitCast(GV, NewPtrTy);
      StoreInst *NewSI = new StoreInst(CI->getOperand(0), NewPtr,
                                       /*volatile*/ SI->isVolatile(), Inst);
      NewSI->takeName(SI);
      NewSI->setDebugLoc(Inst->getDebugLoc());
      Inst->eraseFromParent();
      return NewSI;
    }
  } else if (LI && LI->hasOneUse()) {
    if (auto CI = dyn_cast<BitCastInst>(LI->user_back())) {
      auto NewPtrTy =
          PointerType::get(CI->getType(), LI->getPointerAddressSpace());
      auto NewPtr = ConstantExpr::getBitCast(GV, NewPtrTy);
      auto NewLI = new LoadInst(NewPtrTy->getPointerElementType(), NewPtr, "",
                                /*volatile*/ LI->isVolatile(), Inst);
      NewLI->takeName(LI);
      NewLI->setDebugLoc(LI->getDebugLoc());
      CI->replaceAllUsesWith(NewLI);
      LI->replaceAllUsesWith(UndefValue::get(LI->getType()));
      LI->eraseFromParent();
      return NewLI;
    }
  }

  return nullptr;
}
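
// For example (illustrative IR), the store path rewrites
//   %v = bitcast <8 x i32> %x to <16 x i16>
//   store <16 x i16> %v, <16 x i16>* bitcast (@gv to <16 x i16>*)
// into a store of the original value through a retyped pointer to @gv:
//   store <8 x i32> %x, <8 x i32>* bitcast (@gv to <8 x i32>*)
// The load path does the mirror-image rewrite when the load's only user is
// a bitcast.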

const GlobalVariable *genx::getUnderlyingGlobalVariable(const Value *V) {
  while (auto *BI = dyn_cast<BitCastInst>(V))
    V = BI->getOperand(0);
  while (auto *CE = dyn_cast_or_null<ConstantExpr>(V)) {
    if (CE->getOpcode() == CastInst::BitCast)
      V = CE->getOperand(0);
    else
      break;
  }
  return dyn_cast_or_null<GlobalVariable>(V);
}

GlobalVariable *genx::getUnderlyingGlobalVariable(Value *V) {
  return const_cast<GlobalVariable *>(
      getUnderlyingGlobalVariable(const_cast<const Value *>(V)));
}

const GlobalVariable *genx::getUnderlyingGlobalVariable(const LoadInst *LI) {
  return getUnderlyingGlobalVariable(LI->getPointerOperand());
}

GlobalVariable *genx::getUnderlyingGlobalVariable(LoadInst *LI) {
  return getUnderlyingGlobalVariable(LI->getPointerOperand());
}

bool genx::isGlobalStore(Instruction *I) {
  IGC_ASSERT(I);
  if (auto *SI = dyn_cast<StoreInst>(I))
    return isGlobalStore(SI);
  return false;
}

bool genx::isGlobalStore(StoreInst *ST) {
  IGC_ASSERT(ST);
  return getUnderlyingGlobalVariable(ST->getPointerOperand()) != nullptr;
}

bool genx::isGlobalLoad(Instruction *I) {
  IGC_ASSERT(I);
  if (auto *LI = dyn_cast<LoadInst>(I))
    return isGlobalLoad(LI);
  return false;
}

bool genx::isGlobalLoad(LoadInst *LI) {
  IGC_ASSERT(LI);
  return getUnderlyingGlobalVariable(LI->getPointerOperand()) != nullptr;
}

bool genx::isLegalValueForGlobalStore(Value *V, Value *StorePtr) {
  // Value should be wrregion.
  auto *Wrr = dyn_cast<CallInst>(V);
  if (!Wrr || !GenXIntrinsic::isWrRegion(Wrr))
    return false;

  // With the old value obtained from a load instruction with StorePtr.
  Value *OldVal =
      Wrr->getArgOperand(GenXIntrinsic::GenXRegion::OldValueOperandNum);
  auto *LI = dyn_cast<LoadInst>(OldVal);
  return LI && (getUnderlyingGlobalVariable(LI->getPointerOperand()) ==
                getUnderlyingGlobalVariable(StorePtr));
}

bool genx::isGlobalStoreLegal(StoreInst *ST) {
  IGC_ASSERT(isGlobalStore(ST));
  return isLegalValueForGlobalStore(ST->getValueOperand(),
                                    ST->getPointerOperand());
}

// The following bale will produce identity moves.
// %a0 = load m
// %b0 = load m
// bale {
//   %a1 = rrd %a0, R
//   %b1 = wrr %b0, %a1, R
//   store %b1, m
// }
//
bool genx::isIdentityBale(const Bale &B) {
  if (!B.endsWithGStore())
    return false;

  StoreInst *ST = cast<StoreInst>(B.getHead()->Inst);
  if (B.size() == 1) {
    // The value to be stored should be a load from the same global.
    auto LI = dyn_cast<LoadInst>(ST->getOperand(0));
    return LI && getUnderlyingGlobalVariable(LI->getOperand(0)) ==
                     getUnderlyingGlobalVariable(ST->getOperand(1));
  }
  if (B.size() != 3)
    return false;

  CallInst *B1 = dyn_cast<CallInst>(ST->getValueOperand());
  GlobalVariable *GV = getUnderlyingGlobalVariable(ST->getPointerOperand());
  if (!GenXIntrinsic::isWrRegion(B1) || !GV)
    return false;
  IGC_ASSERT(B1);
  auto B0 = dyn_cast<LoadInst>(B1->getArgOperand(0));
  if (!B0 || GV != getUnderlyingGlobalVariable(B0->getPointerOperand()))
    return false;

  CallInst *A1 = dyn_cast<CallInst>(B1->getArgOperand(1));
  if (!GenXIntrinsic::isRdRegion(A1))
    return false;
  IGC_ASSERT(A1);
  LoadInst *A0 = dyn_cast<LoadInst>(A1->getArgOperand(0));
  if (!A0 || GV != getUnderlyingGlobalVariable(A0->getPointerOperand()))
    return false;

  Region R1 = makeRegionFromBaleInfo(A1, BaleInfo());
  Region R2 = makeRegionFromBaleInfo(B1, BaleInfo());
  return R1 == R2;
}

// Check whether a value's region can be represented as a raw operand.
bool genx::isValueRegionOKForRaw(Value *V, bool IsWrite,
                                 const GenXSubtarget *ST) {
  IGC_ASSERT(V);
  switch (GenXIntrinsic::getGenXIntrinsicID(V)) {
  case GenXIntrinsic::genx_rdregioni:
  case GenXIntrinsic::genx_rdregionf:
    if (IsWrite)
      return false;
    break;
  case GenXIntrinsic::genx_wrregioni:
  case GenXIntrinsic::genx_wrregionf:
    if (!IsWrite)
      return false;
    break;
  default:
    return false;
  }
  Region R = makeRegionFromBaleInfo(cast<Instruction>(V), BaleInfo());
  return isRegionOKForRaw(R, ST);
}

bool genx::isRegionOKForRaw(const genx::Region &R, const GenXSubtarget *ST) {
  unsigned GRFWidth = ST ? ST->getGRFByteSize() : 32;
  if (R.Indirect)
    return false;
  if (R.Offset & (GRFWidth - 1)) // GRF boundary check
    return false;
  if (R.Width != R.NumElements)
    return false;
  if (R.Stride != 1)
    return false;
  return true;
}
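
// For example (illustrative), with a 32-byte GRF a direct region with
// Offset == 64, Width == NumElements and Stride == 1 describes a contiguous,
// GRF-aligned block and is OK as a raw operand; an Offset of 68 would fail
// the GRF boundary check (68 & 31 != 0), as would any indirect or strided
// region.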

bool genx::skipOptWithLargeBlock(FunctionGroup &FG) {
  for (auto fgi = FG.begin(), fge = FG.end(); fgi != fge; ++fgi) {
    auto F = *fgi;
    if (skipOptWithLargeBlock(*F))
      return true;
  }
  return false;
}

std::string genx::getInlineAsmCodes(const InlineAsm::ConstraintInfo &Info) {
  return Info.Codes.front();
}

bool genx::isInlineAsmMatchingInputConstraint(
    const InlineAsm::ConstraintInfo &Info) {
  return isdigit(Info.Codes.front()[0]);
}

genx::ConstraintType genx::getInlineAsmConstraintType(StringRef Codes) {
  return llvm::StringSwitch<genx::ConstraintType>(Codes)
      .Case("r", ConstraintType::Constraint_r)
      .Case("rw", ConstraintType::Constraint_rw)
      .Case("i", ConstraintType::Constraint_i)
      .Case("n", ConstraintType::Constraint_n)
      .Case("F", ConstraintType::Constraint_F)
      .Case("cr", ConstraintType::Constraint_cr)
      .Case("a", ConstraintType::Constraint_a)
      .Default(ConstraintType::Constraint_unknown);
}

unsigned
genx::getInlineAsmMatchedOperand(const InlineAsm::ConstraintInfo &Info) {
  IGC_ASSERT_MESSAGE(genx::isInlineAsmMatchingInputConstraint(Info),
                     "Matching input expected");
  int OperandValue = std::stoi(Info.Codes.front());
  IGC_ASSERT(OperandValue >= 0);
  return OperandValue;
}
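
// For example (illustrative), an inline asm with the constraint string
// "=r,0" has a matching input constraint: the input code "0" starts with a
// digit, so it must reuse the register of output operand 0, and
// getInlineAsmMatchedOperand returns 0 for it.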

std::vector<GenXInlineAsmInfo> genx::getGenXInlineAsmInfo(MDNode *MD) {
  std::vector<GenXInlineAsmInfo> Result;
  for (auto &MDOp : MD->operands()) {
    auto EntryMD = dyn_cast<MDTuple>(MDOp);
    IGC_ASSERT_MESSAGE(EntryMD, "error setting metadata for inline asm");
    IGC_ASSERT_MESSAGE(EntryMD->getNumOperands() == 3,
                       "error setting metadata for inline asm");
    ConstantAsMetadata *Op0 =
        dyn_cast<ConstantAsMetadata>(EntryMD->getOperand(0));
    ConstantAsMetadata *Op1 =
        dyn_cast<ConstantAsMetadata>(EntryMD->getOperand(1));
    ConstantAsMetadata *Op2 =
        dyn_cast<ConstantAsMetadata>(EntryMD->getOperand(2));
    IGC_ASSERT_MESSAGE(Op0, "error setting metadata for inline asm");
    IGC_ASSERT_MESSAGE(Op1, "error setting metadata for inline asm");
    IGC_ASSERT_MESSAGE(Op2, "error setting metadata for inline asm");
    auto CTy = static_cast<genx::ConstraintType>(
        cast<ConstantInt>(Op0->getValue())->getZExtValue());
    Result.emplace_back(CTy, cast<ConstantInt>(Op1->getValue())->getSExtValue(),
                        cast<ConstantInt>(Op2->getValue())->getZExtValue());
  }
  return Result;
}

std::vector<GenXInlineAsmInfo> genx::getGenXInlineAsmInfo(CallInst *CI) {
  IGC_ASSERT_MESSAGE(CI->isInlineAsm(), "Inline asm expected");
  MDNode *MD = CI->getMetadata(genx::MD_genx_inline_asm_info);
  // empty constraint info
  if (!MD) {
    auto *IA = cast<InlineAsm>(IGCLLVM::getCalledValue(CI));
    IGC_ASSERT_MESSAGE(IA->getConstraintString().empty(),
                       "No info only for empty constraint string");
    (void)IA;
    return std::vector<GenXInlineAsmInfo>();
  }
  return genx::getGenXInlineAsmInfo(MD);
}

bool genx::hasConstraintOfType(
    const std::vector<GenXInlineAsmInfo> &ConstraintsInfo,
    genx::ConstraintType CTy) {
  return llvm::any_of(ConstraintsInfo, [&](const GenXInlineAsmInfo &Info) {
    return Info.getConstraintType() == CTy;
  });
}

unsigned genx::getInlineAsmNumOutputs(CallInst *CI) {
  IGC_ASSERT_MESSAGE(CI->isInlineAsm(), "Inline asm expected");
  unsigned NumOutputs;
  if (CI->getType()->isVoidTy())
    NumOutputs = 0;
  else if (auto ST = dyn_cast<StructType>(CI->getType()))
    NumOutputs = ST->getNumElements();
  else
    NumOutputs = 1;
  return NumOutputs;
}

/* for <1 x Ty> returns Ty
 * for Ty returns <1 x Ty>
 * other cases are unsupported
 */
Type *genx::getCorrespondingVectorOrScalar(Type *Ty) {
  if (Ty->isVectorTy()) {
    IGC_ASSERT_MESSAGE(
        cast<IGCLLVM::FixedVectorType>(Ty)->getNumElements() == 1,
        "wrong argument: scalar or degenerate vector is expected");
    return Ty->getScalarType();
  }
  return IGCLLVM::FixedVectorType::get(Ty, 1);
}

// documented at the main template function
CastInst *genx::scalarizeOrVectorizeIfNeeded(Instruction *Inst, Type *RefType) {
  return scalarizeOrVectorizeIfNeeded(Inst, &RefType, std::next(&RefType));
}

// documented at the main template function
CastInst *genx::scalarizeOrVectorizeIfNeeded(Instruction *Inst,
                                             Instruction *InstToReplace) {
  return scalarizeOrVectorizeIfNeeded(Inst, &InstToReplace,
                                      std::next(&InstToReplace));
}

Function *genx::getFunctionPointerFunc(Value *V) {
  Instruction *I = nullptr;
  for (; (I = dyn_cast<CastInst>(V)); V = I->getOperand(0))
    ;
  ConstantExpr *CE = nullptr;
  for (; (CE = dyn_cast<ConstantExpr>(V)) &&
         (CE->getOpcode() == Instruction::ExtractElement || CE->isCast());
       V = CE->getOperand(0))
    ;
  if (auto *F = dyn_cast<Function>(V))
    return F;
  if (auto *CV = dyn_cast<ConstantVector>(V); CV && CV->getSplatValue())
    return getFunctionPointerFunc(CV->getSplatValue());
  return nullptr;
}

bool genx::isFuncPointerVec(Value *V) {
  bool Res = false;
  if (V->getType()->isVectorTy() && isa<ConstantExpr>(V) &&
      cast<ConstantExpr>(V)->getOpcode() == Instruction::BitCast)
    Res = isFuncPointerVec(cast<ConstantExpr>(V)->getOperand(0));
  else if (ConstantVector *Vec = dyn_cast<ConstantVector>(V))
    Res = std::all_of(Vec->op_begin(), Vec->op_end(), [](Value *V) {
      return getFunctionPointerFunc(V) != nullptr;
    });
  return Res;
}

unsigned genx::getLogAlignment(VISA_Align Align, unsigned GRFWidth) {
  switch (Align) {
  case ALIGN_BYTE:
    return Log2_32(ByteBytes);
  case ALIGN_WORD:
    return Log2_32(WordBytes);
  case ALIGN_DWORD:
    return Log2_32(DWordBytes);
  case ALIGN_QWORD:
    return Log2_32(QWordBytes);
  case ALIGN_OWORD:
    return Log2_32(OWordBytes);
  case ALIGN_GRF:
    return Log2_32(GRFWidth);
  case ALIGN_2_GRF:
    return Log2_32(GRFWidth) + 1;
  default:
    report_fatal_error("Unknown alignment");
  }
}

VISA_Align genx::getVISA_Align(unsigned LogAlignment, unsigned GRFWidth) {
  if (LogAlignment == Log2_32(ByteBytes))
    return ALIGN_BYTE;
  else if (LogAlignment == Log2_32(WordBytes))
    return ALIGN_WORD;
  else if (LogAlignment == Log2_32(DWordBytes))
    return ALIGN_DWORD;
  else if (LogAlignment == Log2_32(QWordBytes))
    return ALIGN_QWORD;
  else if (LogAlignment == Log2_32(OWordBytes))
    return ALIGN_OWORD;
  else if (LogAlignment == Log2_32(GRFWidth))
    return ALIGN_GRF;
  else if (LogAlignment == Log2_32(GRFWidth) + 1)
    return ALIGN_2_GRF;
  else
    report_fatal_error("Unknown log alignment");
}

unsigned genx::ceilLogAlignment(unsigned LogAlignment, unsigned GRFWidth) {
  if (LogAlignment <= Log2_32(ByteBytes))
    return Log2_32(ByteBytes);
  else if (LogAlignment <= Log2_32(WordBytes))
    return Log2_32(WordBytes);
  else if (LogAlignment <= Log2_32(DWordBytes))
    return Log2_32(DWordBytes);
  else if (LogAlignment <= Log2_32(QWordBytes))
    return Log2_32(QWordBytes);
  else if (LogAlignment <= Log2_32(OWordBytes))
    return Log2_32(OWordBytes);
  else if (LogAlignment <= Log2_32(GRFWidth))
    return Log2_32(GRFWidth);
  else if (LogAlignment <= Log2_32(GRFWidth) + 1)
    return Log2_32(GRFWidth) + 1;
  else
    report_fatal_error("Unknown log alignment");
}
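
// For example (illustrative), with GRFWidth == 32:
//   getLogAlignment(ALIGN_QWORD, 32) == 3          (log2 of 8 bytes)
//   getVISA_Align(5, 32)             == ALIGN_GRF  (2^5 == 32 bytes)
//   ceilLogAlignment(4, 32)          == 4          (exactly the OWORD bucket)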

bool genx::isWrPredRegionLegalSetP(const CallInst &WrPredRegion) {
  IGC_ASSERT_MESSAGE(GenXIntrinsic::getGenXIntrinsicID(&WrPredRegion) ==
                         GenXIntrinsic::genx_wrpredregion,
                     "wrong argument: wrpredregion intrinsic was expected");
  auto &NewValue = *WrPredRegion.getOperand(WrPredRegionOperand::NewValue);
  auto ExecSize =
      NewValue.getType()->isVectorTy()
          ? cast<IGCLLVM::FixedVectorType>(NewValue.getType())->getNumElements()
          : 1;
  auto Offset =
      cast<ConstantInt>(WrPredRegion.getOperand(WrPredRegionOperand::Offset))
          ->getZExtValue();
  // ExecSize must be a power of two and no larger than 32.
  if (ExecSize > 32 || !isPowerOf2_64(ExecSize))
    return false;
  if (ExecSize == 32)
    return Offset == 0;
  return Offset == 0 || Offset == 16;
}

CallInst *genx::checkFunctionCall(Value *V, Function *F) {
  if (!V || !F)
    return nullptr;
  auto *CI = dyn_cast<CallInst>(V);
  if (CI && CI->getCalledFunction() == F)
    return CI;
  return nullptr;
}

unsigned genx::getNumGRFsPerIndirectForRegion(const genx::Region &R,
                                              const GenXSubtarget *ST,
                                              bool Allow2D) {
  IGC_ASSERT_MESSAGE(R.Indirect, "Indirect region expected");
  IGC_ASSERT(ST);
  if (ST->hasIndirectGRFCrossing() &&
      // SKL+. See if we can allow GRF crossing.
      (Allow2D || !R.is2D())) {
    return 2;
  }
  return 1;
}
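
// For example (illustrative), on subtargets where hasIndirectGRFCrossing()
// is true (SKL+), a 1D indirect region may address elements that cross from
// one GRF into the next, so legalization can allow 2 GRFs per indirect
// operand; on older targets each indirect access stays within a single GRF.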

bool genx::isRealGlobalVariable(const GlobalVariable &GV) {
  if (GV.hasAttribute("genx_volatile"))
    return false;
  if (GV.hasAttribute(genx::VariableMD::VCPredefinedVariable))
    return false;
  bool IsIndexedString =
      std::any_of(GV.user_begin(), GV.user_end(), [](const User *Usr) {
        return vc::isLegalPrintFormatIndexGEP(*Usr);
      });
  if (IsIndexedString) {
    IGC_ASSERT_MESSAGE(std::all_of(GV.user_begin(), GV.user_end(),
                                   [](const User *Usr) {
                                     return vc::isLegalPrintFormatIndexGEP(
                                         *Usr);
                                   }),
                       "when global is an indexed string, its users can only "
                       "be print format index GEPs");
    return false;
  }
  return true;
}

std::size_t genx::getStructElementPaddedSize(unsigned ElemIdx,
                                             unsigned NumOperands,
                                             const StructLayout &Layout) {
  IGC_ASSERT_MESSAGE(ElemIdx < NumOperands,
                     "wrong argument: invalid index into a struct");
  if (ElemIdx == NumOperands - 1)
    return Layout.getSizeInBytes() - Layout.getElementOffset(ElemIdx);
  return Layout.getElementOffset(ElemIdx + 1) -
         Layout.getElementOffset(ElemIdx);
}
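
// For example (illustrative), for struct { i32, i8 } laid out with 4-byte
// alignment the element offsets are {0, 4} and the total size is 8: element
// 0 gets padded size 4 (next offset 4 - offset 0), and element 1, being the
// last element, gets padded size 8 - 4 == 4, which folds the 3 bytes of tail
// padding into it.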

// splitStructPhi : split a struct-typed phi node into one phi per struct
// element, recombining the elements after the phis
bool genx::splitStructPhi(PHINode *Phi) {
  StructType *Ty = cast<StructType>(Phi->getType());
  // Find where we need to insert the combine instructions.
  Instruction *CombineInsertBefore = Phi->getParent()->getFirstNonPHI();
  // Now split the phi.
  Value *Combined = UndefValue::get(Ty);
  // For each struct element...
  for (unsigned Idx = 0, e = Ty->getNumElements(); Idx != e; ++Idx) {
    Type *ElTy = Ty->getTypeAtIndex(Idx);
    // Create the new phi node.
    PHINode *NewPhi =
        PHINode::Create(ElTy, Phi->getNumIncomingValues(),
                        Phi->getName() + ".element" + Twine(Idx), Phi);
    NewPhi->setDebugLoc(Phi->getDebugLoc());
    // Combine the new phi.
    Instruction *Combine = InsertValueInst::Create(
        Combined, NewPhi, Idx, NewPhi->getName(), CombineInsertBefore);
    Combine->setDebugLoc(Phi->getDebugLoc());
    Combined = Combine;
    // For each incoming...
    for (unsigned In = 0, InEnd = Phi->getNumIncomingValues(); In != InEnd;
         ++In) {
      // Create an extractvalue to get the individual element value.
      // This needs to go before the terminator of the incoming block.
      BasicBlock *IncomingBB = Phi->getIncomingBlock(In);
      Value *Incoming = Phi->getIncomingValue(In);
      Instruction *Extract = ExtractValueInst::Create(
          Incoming, Idx, Phi->getName() + ".element" + Twine(Idx),
          IncomingBB->getTerminator());
      Extract->setDebugLoc(Phi->getDebugLoc());
      // Add as an incoming of the new phi node.
      NewPhi->addIncoming(Extract, IncomingBB);
    }
  }
  Phi->replaceAllUsesWith(Combined);
  Phi->eraseFromParent();
  return true;
}
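
// For example (illustrative IR, names invented), a phi over { i32, float }:
//   %p = phi { i32, float } [ %a, %bb0 ], [ %b, %bb1 ]
// becomes per-element phis fed by extractvalues in the incoming blocks,
// recombined with insertvalues at the first non-phi position:
//   %p.element0 = phi i32   [ %a.e0, %bb0 ], [ %b.e0, %bb1 ]
//   %p.element1 = phi float [ %a.e1, %bb0 ], [ %b.e1, %bb1 ]
//   %c0 = insertvalue { i32, float } undef, i32 %p.element0, 0
//   %c1 = insertvalue { i32, float } %c0, float %p.element1, 1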

bool genx::splitStructPhis(Function *F) {
  bool Modified = false;
  for (Function::iterator fi = F->begin(), fe = F->end(); fi != fe; ++fi) {
    BasicBlock *BB = &*fi;
    for (BasicBlock::iterator bi = BB->begin();;) {
      PHINode *Phi = dyn_cast<PHINode>(&*bi);
      if (!Phi)
        break;
      ++bi; // increment here as splitStructPhi removes the old phi node
      if (isa<StructType>(Phi->getType()))
        Modified |= splitStructPhi(Phi);
    }
  }
  return Modified;
}

bool genx::hasMemoryDeps(Instruction *L1, Instruction *L2, Value *Addr,
                         DominatorTree *DT) {
  // Return false for non-global loads.
  if (!(GenXIntrinsic::isVLoad(L1) && GenXIntrinsic::isVLoad(L2)) &&
      !(isGlobalLoad(L1) && isGlobalLoad(L2)))
    return false;

  auto isKill = [=](Instruction &I) {
    Instruction *Inst = &I;
    if ((GenXIntrinsic::isVStore(Inst) || genx::isGlobalStore(Inst)) &&
        (Addr == Inst->getOperand(1) ||
         Addr == getUnderlyingGlobalVariable(Inst->getOperand(1))))
      return true;
    // OK.
    return false;
  };

  // global loads from the same block.
  if (L1->getParent() == L2->getParent()) {
    BasicBlock::iterator I = L1->getParent()->begin();
    for (; &*I != L1 && &*I != L2; ++I)
      /*empty*/;
    IGC_ASSERT(&*I == L1 || &*I == L2);
    auto IEnd = (&*I == L1) ? L2->getIterator() : L1->getIterator();
    return std::any_of(I->getIterator(), IEnd, isKill);
  }

  // global loads are from different blocks.
  //
  //  BB1 (L1)
  //  /   \
  // BB3  BB2 (L2)
  //  \   /
  //   BB4
  //
  auto BB1 = L1->getParent();
  auto BB2 = L2->getParent();
  if (!DT->properlyDominates(BB1, BB2)) {
    std::swap(BB1, BB2);
    std::swap(L1, L2);
  }
  if (DT->properlyDominates(BB1, BB2)) {
    // As BB1 dominates BB2, we can recursively check BB2's predecessors,
    // until reaching BB1.
    //
    // check BB1 && BB2
    if (std::any_of(BB2->begin(), L2->getIterator(), isKill))
      return true;
    if (std::any_of(L1->getIterator(), BB1->end(), isKill))
      return true;
    std::set<BasicBlock *> Visited{BB1, BB2};
    std::vector<BasicBlock *> BBs;
    std::copy_if(pred_begin(BB2), pred_end(BB2), std::back_inserter(BBs),
                 [Visited](BasicBlock *BB) { return !Visited.count(BB); });

    // This visits the subgraph dominated by BB1, originated from BB2.
    while (!BBs.empty()) {
      BasicBlock *BB = BBs.back();
      BBs.pop_back();
      Visited.insert(BB);

      // check if there is any store kill in this block.
      if (std::any_of(BB->begin(), BB->end(), isKill))
        return true;

      // Populate not-yet-visited predecessors.
      std::copy_if(pred_begin(BB), pred_end(BB), std::back_inserter(BBs),
                   [Visited](BasicBlock *BB) { return !Visited.count(BB); });
    }

    // no mem deps.
    return false;
  }

  return true;
}

bool genx::isRdRFromGlobalLoad(Value *V) {
  if (!GenXIntrinsic::isRdRegion(V))
    return false;
  auto *RdR = cast<CallInst>(V);
  auto *I = dyn_cast<Instruction>(
      RdR->getArgOperand(GenXIntrinsic::GenXRegion::OldValueOperandNum));
  return I && isGlobalLoad(I);
}

bool genx::isWrRToGlobalLoad(Value *V) {
  if (!GenXIntrinsic::isWrRegion(V))
    return false;
  auto *WrR = cast<CallInst>(V);
  auto *I = dyn_cast<Instruction>(
      WrR->getArgOperand(GenXIntrinsic::GenXRegion::OldValueOperandNum));
  return I && isGlobalLoad(I);
}