1 /*========================== begin_copyright_notice ============================
2
3 Copyright (C) 2017-2021 Intel Corporation
4
5 SPDX-License-Identifier: MIT
6
7 ============================= end_copyright_notice ===========================*/
8
9 //
10 /// GenXRegionCollapsing
11 /// --------------------
12 ///
/// The GenX region collapsing pass is a function pass that collapses nested
/// read regions or nested write regions.
15 ///
16 /// Nested region accesses can occur in two ways (or a mixture of both):
17 ///
18 /// 1. The front end compiler deliberately generates nested region access. The
19 /// CM compiler does this for a matrix select, generating a region access for
20 /// the rows and another one for the columns, safe in the knowledge that this
21 /// pass will combine them where it can.
22 ///
23 /// 2. Two region accesses in different source code constructs (e.g. two select()
24 /// calls, either in the same or different source statements).
25 ///
26 /// The combineRegions() function is what makes the decisions on whether two
27 /// regions can be collapsed, depending on whether they are 1D or 2D, how the
28 /// rows of one fit in the rows of the other, whether each is indirect, etc.
29 ///
30 /// This pass makes an effort to combine two region accesses even if there are
31 /// multiple bitcasts (from CM format()) or up to one SExt/ZExt (from a cast) in
32 /// between.
33 ///
34 //===----------------------------------------------------------------------===//
35 #define DEBUG_TYPE "GENX_RegionCollapsing"
36
37 #include "GenX.h"
38 #include "GenXBaling.h"
39 #include "GenXUtil.h"
40
41 #include "llvm/ADT/PostOrderIterator.h"
42 #include "llvm/Analysis/CFG.h"
43 #include "llvm/Analysis/InstructionSimplify.h"
44 #include "llvm/IR/BasicBlock.h"
45 #include "llvm/IR/Constants.h"
46 #include "llvm/IR/DataLayout.h"
47 #include "llvm/IR/Dominators.h"
48 #include "llvm/IR/Function.h"
49 #include "llvm/IR/Instructions.h"
50 #include "llvm/IR/Intrinsics.h"
51 #include "llvm/InitializePasses.h"
52 #include "llvm/Support/Debug.h"
53 #include "llvm/Transforms/Utils/Local.h"
54 #include "Probe/Assertion.h"
55
56 #include "llvmWrapper/IR/DerivedTypes.h"
57
58 using namespace llvm;
59 using namespace genx;
60
namespace {

// GenX region collapsing pass.
//
// Collapses nested rdregion/wrregion accesses within each function; see the
// file comment above for how such nests arise and when they can combine.
class GenXRegionCollapsing : public FunctionPass {
  const DataLayout *DL = nullptr; // module data layout, set in runOnFunction
  DominatorTree *DT = nullptr;    // dominator tree of the current function
  bool Modified = false;          // set when the current block scan changed IR
public:
  static char ID;
  explicit GenXRegionCollapsing() : FunctionPass(ID) { }
  StringRef getPassName() const override { return "GenX Region Collapsing"; }
  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.addRequired<DominatorTreeWrapperPass>();
    AU.setPreservesCFG();
  }
  bool runOnFunction(Function &F) override;

private:
  // Per-block driver: simplification sweep, then collapsing sweep.
  void runOnBasicBlock(BasicBlock *BB);
  // Hoist a bitcast above the rdregion that feeds it.
  void processBitCast(BitCastInst *BC);
  // Try to combine a rdregion with an earlier rdregion feeding it.
  void processRdRegion(Instruction *InnerRd);
  // Split an indirect+replicating rdregion so the indirect part reads once.
  void splitReplicatingIndirectRdRegion(Instruction *Rd, Region *R);
  // Remove a wrregion made redundant by a later write of the same region.
  void processWrRegionElim(Instruction *OuterWr);
  // Fold a scalar<->1-vector bitcast feeding a wrregion's "new value".
  Instruction *processWrRegionBitCast(Instruction *WrRegion);
  // Retype a wrregion whose "new value" is some other bitcast.
  void processWrRegionBitCast2(Instruction *WrRegion);
  // Wrregion collapsing proper (definitions outside this excerpt).
  Instruction *processWrRegion(Instruction *OuterWr);
  Instruction *processWrRegionSplat(Instruction *OuterWr);
  // Give two regions a common element type, preferring R1's if PreferFirst.
  bool normalizeElementType(Region *R1, Region *R2, bool PreferFirst = false);
  // Decides whether two regions can collapse and computes the result region.
  bool combineRegions(const Region *OuterR, const Region *InnerR,
                      Region *CombinedR);
  // Build the index computation for a collapsed indirect region.
  void calculateIndex(const Region *OuterR, const Region *InnerR,
                      Region *CombinedR, Value *InnerIndex, const Twine &Name,
                      Instruction *InsertBefore, const DebugLoc &DL);
  Value *insertOp(Instruction::BinaryOps Opcode, Value *Lhs, unsigned Rhs,
                  const Twine &Name, Instruction *InsertBefore,
                  const DebugLoc &DL);
  Value *insertOp(Instruction::BinaryOps Opcode, Value *Lhs, Value *Rhs,
                  const Twine &Name, Instruction *InsertBefore,
                  const DebugLoc &DL);
  // True if I is a rdregion extracting a single element (see its uses below).
  bool isSingleElementRdRExtract(Instruction *I);
};

} // end anonymous namespace
104
105
char GenXRegionCollapsing::ID = 0;
// Register the pass with LLVM's legacy pass registry; the dominator tree is
// required (see getAnalysisUsage).
namespace llvm { void initializeGenXRegionCollapsingPass(PassRegistry &); }
INITIALIZE_PASS_BEGIN(GenXRegionCollapsing, "GenXRegionCollapsing",
                      "GenXRegionCollapsing", false, false)
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
INITIALIZE_PASS_END(GenXRegionCollapsing, "GenXRegionCollapsing",
                    "GenXRegionCollapsing", false, false)
113
114 // Publicly exposed interface to pass...
115 FunctionPass *llvm::createGenXRegionCollapsingPass()
116 {
117 initializeGenXRegionCollapsingPass(*PassRegistry::getPassRegistry());
118 return new GenXRegionCollapsing();
119 }
120
121 /***********************************************************************
122 * runOnFunction : run the region collapsing pass for this Function
123 */
runOnFunction(Function & F)124 bool GenXRegionCollapsing::runOnFunction(Function &F)
125 {
126 DL = &F.getParent()->getDataLayout();
127 DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();
128
129 // Track if there is any modification to the function.
130 bool Changed = false;
131
132 // This does a postordered depth first traversal of the CFG, processing
133 // instructions within a basic block in reverse, to ensure that we see a def
134 // after its uses (ignoring phi node uses).
135 for (po_iterator<BasicBlock *> i = po_begin(&F.getEntryBlock()),
136 e = po_end(&F.getEntryBlock());
137 i != e; ++i) {
138 // Iterate until there is no modification.
139 BasicBlock *BB = *i;
140 do {
141 Modified = false;
142 runOnBasicBlock(BB);
143 if (Modified)
144 Changed = true;
145 } while (Modified);
146 }
147
148 return Changed;
149 }
150
/***********************************************************************
 * lowerTrunc : turn a trunc whose input is a rdregion into a bitcast
 *    followed by a strided rdregion
 *
 * Enter:   Inst = the trunc instruction
 *
 * Return:  true if lowered (Inst is left unused; the caller's later
 *          SimplifyInstructionsInBlock erases it), false otherwise
 *
 * Expressing the trunc as bitcast + strided read lets this pass collapse
 * the new rdregion with the rdregion already feeding the trunc.
 */
static bool lowerTrunc(TruncInst *Inst) {
  Value *InValue = Inst->getOperand(0);
  // Only of interest when the trunc input comes straight from a rdregion.
  if (!GenXIntrinsic::isRdRegion(InValue))
    return false;

  Type *InElementTy = InValue->getType();
  Type *OutElementTy = Inst->getType();
  unsigned NumElements = 1;
  // For a vector trunc, work on the element types.
  if (auto *VT = dyn_cast<IGCLLVM::FixedVectorType>(InElementTy)) {
    InElementTy = VT->getElementType();
    OutElementTy = cast<VectorType>(OutElementTy)->getElementType();
    NumElements = VT->getNumElements();
  }
  unsigned OutBitSize = OutElementTy->getPrimitiveSizeInBits();
  IGC_ASSERT(OutBitSize);
  // Do not touch truncations to i1 or vector of i1 types.
  if (OutBitSize == 1)
    return false;
  // Each truncated value is the low part of a wider element, so after the
  // bitcast we read every Stride'th narrow element.
  unsigned Stride = InElementTy->getPrimitiveSizeInBits() / OutBitSize;

  // Create the new bitcast.
  Instruction *BC = CastInst::Create(
      Instruction::BitCast, InValue,
      IGCLLVM::FixedVectorType::get(OutElementTy, Stride * NumElements),
      Inst->getName(), Inst /*InsertBefore*/);
  BC->setDebugLoc(Inst->getDebugLoc());

  // Create the new rdregion: a single row of NumElements, stride Stride.
  Region R(BC);
  R.NumElements = NumElements;
  R.Stride = Stride;
  R.Width = NumElements;
  R.VStride = R.Stride * R.Width;
  Instruction *NewInst = R.createRdRegion(
      BC, Inst->getName(), Inst /*InsertBefore*/, Inst->getDebugLoc(),
      !isa<VectorType>(Inst->getType()) /*AllowScalar*/);

  // Change uses; the now-dead trunc is erased later by the caller.
  Inst->replaceAllUsesWith(NewInst);
  return true;
}
192
/***********************************************************************
 * runOnBasicBlock : process one basic block
 *
 * Phase 1 (forward sweep): local simplifications that expose collapsing
 * opportunities - lower trunc to bitcast+rdregion, scalarize 1-element
 * rdregion+extractelement pairs, simplify region intrinsics, and sink
 * constant adds in index operands.
 * Phase 2 (backward sweep): the actual collapsing of rd/wrregions and
 * bitcasts. Sets Modified if anything changed.
 */
void GenXRegionCollapsing::runOnBasicBlock(BasicBlock *BB) {
  // Code simplification in block first.
  // E is the terminator, which is deliberately never visited.
  for (auto BI = BB->begin(), E = --BB->end(); BI != E;) {
    IGC_ASSERT(!BI->isTerminator());
    Instruction *Inst = &*BI++;
    if (Inst->use_empty())
      continue;

    // Turn trunc into bitcast followed by rdr. This helps region collapsing in
    // a later stage.
    if (auto TI = dyn_cast<TruncInst>(Inst)) {
      Modified |= lowerTrunc(TI);
      continue;
    }

    // Simplify
    // %1 = call <1 x i32> @rdr(...)
    // %2 = extractelement <1 x i32> %1, i32 0
    // into
    // %2 = call i32 @rdr(...)
    //
    if (auto EEI = dyn_cast<ExtractElementInst>(Inst)) {
      Value *Src = EEI->getVectorOperand();
      if (GenXIntrinsic::isRdRegion(Src) &&
          cast<IGCLLVM::FixedVectorType>(Src->getType())->getNumElements() ==
              1) {
        // Create a new region with scalar output.
        Region R(Inst);
        Instruction *NewInst =
            R.createRdRegion(Src, Inst->getName(), Inst /*InsertBefore*/,
                             Inst->getDebugLoc(), true /*AllowScalar*/);
        Inst->replaceAllUsesWith(NewInst);
        Modified = true;
        continue;
      }
    }

    // Constant-fold / simplify region intrinsics where possible.
    if (Value *V = simplifyRegionInst(Inst, DL)) {
      Inst->replaceAllUsesWith(V);
      Modified = true;
      continue;
    }

    // sink index calculation before region collapsing. For collapsed regions,
    // it is more difficult to lift constant offsets.
    static const unsigned NOT_INDEX = 255;
    unsigned Index = NOT_INDEX;

    unsigned IID = GenXIntrinsic::getGenXIntrinsicID(Inst);
    if (GenXIntrinsic::isRdRegion(IID))
      Index = GenXIntrinsic::GenXRegion::RdIndexOperandNum;
    else if (GenXIntrinsic::isWrRegion(IID))
      Index = GenXIntrinsic::GenXRegion::WrIndexOperandNum;
    else if (isa<InsertElementInst>(Inst))
      Index = 2; // insertelement index operand
    else if (isa<ExtractElementInst>(Inst))
      Index = 1; // extractelement index operand

    if (Index != NOT_INDEX) {
      Use *U = &Inst->getOperandUse(Index);
      Value *V = sinkAdd(*U);
      if (V != U->get()) {
        *U = V;
        Modified = true;
      }
    }
  }
  Modified |= SimplifyInstructionsInBlock(BB);

  // This loop processes instructions in reverse, tolerating an instruction
  // being removed during its processing, and not re-processing any new
  // instructions added during the processing of an instruction.
  for (Instruction *Prev = BB->getTerminator(); Prev;) {
    Instruction *Inst = Prev;
    Prev = nullptr;
    // Remember the predecessor before Inst can be erased below.
    if (Inst != &BB->front())
      Prev = Inst->getPrevNode();
    switch (GenXIntrinsic::getGenXIntrinsicID(Inst)) {
    case GenXIntrinsic::genx_rdregioni:
    case GenXIntrinsic::genx_rdregionf:
      processRdRegion(Inst);
      break;
    case GenXIntrinsic::genx_wrregioni:
    case GenXIntrinsic::genx_wrregionf:
      // Wrregions go through the whole pipeline of wrregion transforms;
      // each step may replace Inst with a new instruction.
      processWrRegionElim(Inst);
      if (!Inst->use_empty()) {
        if (auto NewInst = processWrRegionBitCast(Inst)) {
          Modified = true;
          Inst = NewInst;
        }
        auto NewInst1 = processWrRegionSplat(Inst);
        if (Inst != NewInst1) {
          Modified = true;
          Inst = NewInst1;
        }

        auto NewInst = processWrRegion(Inst);
        processWrRegionBitCast2(NewInst);
        if (Inst != NewInst && NewInst->use_empty()) {
          NewInst->eraseFromParent();
          Modified = true;
        }
      }
      if (Inst->use_empty()) {
        Inst->eraseFromParent();
        Modified = true;
      }
      break;
    default:
      if (auto BC = dyn_cast<BitCastInst>(Inst))
        processBitCast(BC);
      if (isa<CastInst>(Inst) && Inst->use_empty()) {
        // Remove bitcast that has become unused due to changes in this pass.
        Inst->eraseFromParent();
        Modified = true;
      }
      break;
    }
  }
}
313
314 /***********************************************************************
315 * createBitCast : create a bitcast, combining bitcasts where applicable
316 */
createBitCast(Value * Input,Type * Ty,const Twine & Name,Instruction * InsertBefore,const DebugLoc & DL)317 static Value *createBitCast(Value *Input, Type *Ty, const Twine &Name,
318 Instruction *InsertBefore, const DebugLoc &DL) {
319 if (Input->getType() == Ty)
320 return Input;
321 if (auto BC = dyn_cast<BitCastInst>(Input))
322 Input = BC->getOperand(0);
323 if (Input->getType() == Ty)
324 return Input;
325 auto NewBC = CastInst::Create(Instruction::BitCast, Input, Ty,
326 Name, InsertBefore);
327 NewBC->setDebugLoc(DL);
328 return NewBC;
329 }
330
/***********************************************************************
 * createBitCastToElementType : create a bitcast to a vector with the
 *    specified element type, combining bitcasts where applicable
 *
 * The element count is chosen so the result has the same size in bytes as
 * Input; Input's size must be a whole multiple of the element size.
 */
static Value *createBitCastToElementType(Value *Input, Type *ElementTy,
                                         const Twine &Name,
                                         Instruction *InsertBefore,
                                         const DataLayout *DL,
                                         const DebugLoc &DbgLoc) {
  unsigned ElBytes = ElementTy->getPrimitiveSizeInBits() / 8U;
  if (!ElBytes) {
    // Zero primitive size: only function pointers are expected here, whose
    // size comes from the data layout instead.
    IGC_ASSERT(ElementTy->isPointerTy());
    IGC_ASSERT(ElementTy->getPointerElementType()->isFunctionTy());
    ElBytes = DL->getTypeSizeInBits(ElementTy) / 8;
  }
  unsigned InputBytes = Input->getType()->getPrimitiveSizeInBits() / 8U;
  if (!InputBytes) {
    // Same for the input, possibly a vector of function pointers.
    // NOTE(review): this takes the size of ONE element only, not the whole
    // vector - confirm callers never pass a multi-element pointer vector.
    Type *T = Input->getType();
    if (T->isVectorTy())
      T = cast<VectorType>(T)->getElementType();
    IGC_ASSERT(T->isPointerTy());
    IGC_ASSERT(T->getPointerElementType()->isFunctionTy());
    InputBytes = DL->getTypeSizeInBits(T) / 8;
  }
  // Mask test assumes ElBytes is a power of two.
  IGC_ASSERT_MESSAGE(!(InputBytes & (ElBytes - 1)), "non-integral number of elements");
  auto Ty = IGCLLVM::FixedVectorType::get(ElementTy, InputBytes / ElBytes);
  return createBitCast(Input, Ty, Name, InsertBefore, DbgLoc);
}
359
360 /***********************************************************************
361 * combineBitCastWithUser : if PossibleBC is a bitcast, and it has a single
362 * user that is also a bitcast, then combine them
363 *
364 * If combined, the two bitcast instructions are erased.
365 *
366 * This can happen because combining two rdregions with a bitcast between
367 * them can result in the bitcast being used by another bitcast that was
368 * already there.
369 */
combineBitCastWithUser(Value * PossibleBC)370 static void combineBitCastWithUser(Value *PossibleBC)
371 {
372 if (auto BC1 = dyn_cast<BitCastInst>(PossibleBC)) {
373 if (BC1->hasOneUse()) {
374 if (auto BC2 = dyn_cast<BitCastInst>(BC1->use_begin()->getUser())) {
375 Value *CombinedBC = BC1->getOperand(0);
376 if (CombinedBC->getType() != BC2->getType())
377 CombinedBC = createBitCast(BC1->getOperand(0), BC2->getType(),
378 BC2->getName(), BC2, BC2->getDebugLoc());
379 BC2->replaceAllUsesWith(CombinedBC);
380 BC2->eraseFromParent();
381 BC1->eraseFromParent();
382 }
383 }
384 }
385 }
386
/***********************************************************************
 * processBitCast : process a bitcast whose input is rdregion
 *
 * We put the bitcast before the rdregion, in the hope that it will enable
 * the rdregion to be baled in to something later on.
 */
void GenXRegionCollapsing::processBitCast(BitCastInst *BC)
{
  // Leave predicate (i1) bitcasts alone.
  if (BC->getType()->getScalarType()->isIntegerTy(1))
    return;
  auto Rd = dyn_cast<Instruction>(BC->getOperand(0));

  // check if skipping this optimization.
  auto skip = [=] {
    // Skip if this is not rdregion
    if (!Rd || !GenXIntrinsic::isRdRegion(Rd))
      return true;

    Value *OldValue = Rd->getOperand(GenXIntrinsic::GenXRegion::OldValueOperandNum);
    // Do not touch reads of predefined registers.
    if (GenXIntrinsic::isReadWritePredefReg(OldValue))
      return true;

    // Single use, do optimization.
    if (Rd->hasOneUse())
      return false;

    // More than one use: check if the rdregion is reading from a global.
    // If yes, still do the conversion, as the bitcast could be folded into
    // the g_load.
    while (auto CI = dyn_cast<BitCastInst>(OldValue))
      OldValue = CI->getOperand(0);
    auto LI = dyn_cast<LoadInst>(OldValue);
    if (LI && getUnderlyingGlobalVariable(LI->getPointerOperand()))
      return false;

    // skip otherwise
    return true;
  };

  if (skip())
    return;

  // skip call above shall check for RdRegion among other things
  IGC_ASSERT(Rd);
  IGC_ASSERT(GenXIntrinsic::isRdRegion(Rd));

  // We have a single use rdregion as the input to the bitcast.
  // Adjust the region parameters if possible so the element type is that of
  // the result of the bitcast, instead of the input.
  Region ROrig = makeRegionFromBaleInfo(Rd, BaleInfo());
  Region R = makeRegionFromBaleInfo(Rd, BaleInfo());
  auto ElTy = BC->getType()->getScalarType();
  IGC_ASSERT(DL);
  if (!R.changeElementType(ElTy, DL))
    return;

  // we do not want this optimization to be applied if the resulting indirect
  // region will have non-zero stride or non-single width;
  // this would require ineffective legalization in those cases
  bool OrigCorr = ((ROrig.Width == 1) || (ROrig.Stride == 0));
  bool ChangedWrong = ((R.Width != 1) && (R.Stride != 0));
  if (OrigCorr && ChangedWrong && R.Indirect)
    return;

  // Create the new bitcast, over the rdregion's whole input vector.
  IGC_ASSERT(vc::getTypeSize(ElTy, DL).inBits());
  auto Input = Rd->getOperand(GenXIntrinsic::GenXRegion::OldValueOperandNum);
  auto NewBCTy = IGCLLVM::FixedVectorType::get(
      ElTy, vc::getTypeSize(Input->getType(), DL).inBits() /
                vc::getTypeSize(ElTy, DL).inBits());
  auto NewBC = CastInst::Create(Instruction::BitCast, Input, NewBCTy, "", Rd);
  NewBC->takeName(BC);
  NewBC->setDebugLoc(BC->getDebugLoc());
  // Create the new rdregion.
  auto NewRd = R.createRdRegion(NewBC, "", Rd, Rd->getDebugLoc(),
      /*AllowScalar=*/!isa<VectorType>(BC->getType()));
  NewRd->takeName(Rd);
  // Replace uses.
  BC->replaceAllUsesWith(NewRd);
  // Caller removes BC.
  Modified = true;
}
468
/***********************************************************************
 * processRdRegion : process a rdregion
 *
 * 1. If this rdregion is unused, it probably became so in the processing
 *    of a later rdregion. Erase it.
 *
 * 2. Otherwise, see if the input to this rdregion is the result of
 *    an earlier rdregion, and if so see if they can be combined. This can
 *    work even if there are bitcasts and up to one sext/zext between the
 *    two rdregions.
 */
void GenXRegionCollapsing::processRdRegion(Instruction *InnerRd)
{
  if (InnerRd->use_empty()) {
    InnerRd->eraseFromParent();
    Modified = true;
    return;
  }

  // We use genx::makeRegionWithOffset to get a Region object for a
  // rdregion/wrregion throughout this pass, in order to ensure that, with an
  // index that is V+const, we get the V and const separately
  // (in Region::Indirect and Region::Offset).
  // Then our index calculations can ensure that the constant add remains the
  // last thing that happens in the calculation.
  Region InnerR = genx::makeRegionWithOffset(InnerRd,
                                             /*WantParentWidth=*/true);

  // Prevent region collapsing for specific src replication pattern,
  // in order to enable swizzle optimization for Align16 instruction
  // (64-bit element, width-2, stride-0, vstride-2 read feeding fmul then
  // fadd/fsub).
  if (InnerRd->hasOneUse()) {
    if (auto UseInst = dyn_cast<Instruction>(InnerRd->use_begin()->getUser())) {
      if (UseInst->getOpcode() == Instruction::FMul) {
        auto NextInst = dyn_cast<Instruction>(UseInst->use_begin()->getUser());
        if (NextInst &&
            (NextInst->getOpcode() == Instruction::FAdd ||
             NextInst->getOpcode() == Instruction::FSub) &&
            InnerR.ElementTy->getPrimitiveSizeInBits() == 64U &&
            InnerR.Width == 2 &&
            InnerR.Stride == 0 &&
            InnerR.VStride == 2)
          return;
      }
    }
  }

  for (;;) {
    Instruction *OuterRd = dyn_cast<Instruction>(InnerRd->getOperand(0));
    // Go through any bitcasts and up to one sext/zext if necessary to find the
    // outer rdregion.
    Instruction *Extend = nullptr;
    bool HadElementTypeChange = false;
    for (;;) {
      if (!OuterRd)
        break; // input not result of earlier rdregion
      if (GenXIntrinsic::isRdRegion(OuterRd)) {
        Value *OldValue =
            OuterRd->getOperand(GenXIntrinsic::GenXRegion::OldValueOperandNum);
        // Do not optimize predefined register regions.
        if (GenXIntrinsic::isReadWritePredefReg(OldValue))
          OuterRd = nullptr;
        break; // found the outer rdregion
      }
      if (isa<SExtInst>(OuterRd) || isa<ZExtInst>(OuterRd)) {
        if (OuterRd->getOperand(0)->getType()->getScalarType()->isIntegerTy(1)) {
          OuterRd = nullptr;
          break; // input not result of earlier rdregion
        }
        if (Extend || HadElementTypeChange) {
          OuterRd = nullptr;
          break; // can only have one sext/zext between the rdregions, and
                 // sext/zext not allowed if it is then subject to a bitcast
                 // that changes the element type
        }
        // Remember the sext/zext instruction.
        Extend = OuterRd;
      } else if (isa<BitCastInst>(OuterRd)) {
        if (OuterRd->getType()->getScalarType()
            != OuterRd->getOperand(0)->getType()->getScalarType())
          HadElementTypeChange = true;
      } else {
        OuterRd = nullptr;
        break; // input not result of earlier rdregion
      }
      OuterRd = dyn_cast<Instruction>(OuterRd->getOperand(0));
    }
    if (!OuterRd)
      break; // no outer rdregion that we can combine with
    Region OuterR = genx::makeRegionWithOffset(OuterRd);
    // There was a sext/zext. Because we are going to put that after the
    // collapsed region, we want to modify the inner region to the
    // extend's input element type without changing the region parameters
    // (other than scaling the offset). We know that there is no element
    // type changing bitcast between the extend and the inner rdregion.
    if (Extend) {
      if (InnerR.Indirect)
        return; // cannot cope with indexed inner region and sext/zext
      InnerR.ElementTy = Extend->getOperand(0)->getType()->getScalarType();
      unsigned ExtInputElementBytes
          = InnerR.ElementTy->getPrimitiveSizeInBits() / 8U;
      InnerR.Offset = InnerR.Offset / InnerR.ElementBytes * ExtInputElementBytes;
      InnerR.ElementBytes = ExtInputElementBytes;
    }
    // See if the regions can be combined. We call normalizeElementType with
    // InnerR as the first arg so it prefers to normalize to that region's
    // element type if possible. That can avoid a bitcast being put after the
    // combined rdregion, which can help baling later on.
    LLVM_DEBUG(dbgs() << "GenXRegionCollapsing::processRdRegion:\n"
        " OuterRd (line " << OuterRd->getDebugLoc().getLine() << "): " << *OuterRd << "\n"
        " InnerRd (line " << InnerRd->getDebugLoc().getLine() << "): " << *InnerRd << "\n");
    if (!normalizeElementType(&InnerR, &OuterR, /*PreferFirst=*/true)) {
      LLVM_DEBUG(dbgs() << "Cannot normalize element type\n");
      return;
    }

    // If it's a single element extract from an indirect region
    // then check if there exist some other extracts
    if (OuterR.Indirect && (OuterR.NumElements != 1) &&
        isSingleElementRdRExtract(InnerRd)) {
      auto NumExtracts = llvm::count_if(OuterRd->uses(), [this](Use &U) {
        return isSingleElementRdRExtract(cast<Instruction>(U.getUser()));
      });
      // If there are some more extracts except this one (InnerRd)
      // then do not combine these regions, to prevent generation
      // of extra address conversions for a combined region
      if (NumExtracts > 1)
        return;
    }

    Region CombinedR;
    if (!combineRegions(&OuterR, &InnerR, &CombinedR))
      return; // cannot combine

    // If the combined region is both indirect and splat, then do not combine.
    // Otherwise, this leads to an infinite loop as later on we split such
    // region reads (see splitReplicatingIndirectRdRegion).
    auto isIndirectSplat = [](const Region &R) {
      if (!R.Indirect)
        return false;
      if (R.Width != R.NumElements && !R.VStride &&
          !isa<VectorType>(R.Indirect->getType()))
        return true;
      if (R.Width == 1 || R.Stride)
        return false;
      return true;
    };
    if (isIndirectSplat(CombinedR))
      return;

    // Calculate index if necessary.
    if (InnerR.Indirect) {
      calculateIndex(&OuterR, &InnerR, &CombinedR,
          InnerRd->getOperand(GenXIntrinsic::GenXRegion::RdIndexOperandNum),
          InnerRd->getName() + ".indexcollapsed",
          InnerRd, InnerRd->getDebugLoc());
    }
    // If the element type of the combined region does not match that of the
    // outer region, we need to do a bitcast first.
    Value *Input = OuterRd->getOperand(GenXIntrinsic::GenXRegion::OldValueOperandNum);
    // InnerR.ElementTy not always equal to InnerRd->getType()->getScalarType() (look above)
    if (InnerR.ElementTy != OuterRd->getType()->getScalarType())
      Input = createBitCastToElementType(Input, InnerR.ElementTy,
                                         Input->getName() +
                                             ".bitcast_before_collapse",
                                         OuterRd, DL, OuterRd->getDebugLoc());
    // Create the combined rdregion.
    Instruction *CombinedRd = CombinedR.createRdRegion(Input,
        InnerRd->getName() + ".regioncollapsed", InnerRd, InnerRd->getDebugLoc(),
        !isa<VectorType>(InnerRd->getType()));
    // If we went through sext/zext, re-instate it here.
    Value *NewVal = CombinedRd;
    if (Extend) {
      auto NewCI = CastInst::Create((Instruction::CastOps)Extend->getOpcode(),
          NewVal, InnerRd->getType(), Extend->getName(), InnerRd);
      NewCI->setDebugLoc(Extend->getDebugLoc());
      NewVal = NewCI;
    }
    // If we still don't have the right type due to bitcasts in the original
    // code, add a bitcast here.
    NewVal = createBitCast(NewVal, InnerRd->getType(),
        NewVal->getName() + ".bitcast_after_collapse", InnerRd,
        InnerRd->getDebugLoc());
    // Replace the inner read with the new value, and erase the inner read;
    // any other instructions between it and the outer read (inclusive) that
    // become unused are cleaned up separately.
    InnerRd->replaceAllUsesWith(NewVal);
    InnerRd->eraseFromParent();
    Modified = true;
    // Check whether we just created a bitcast that can be combined with its
    // user. If so, combine them.
    combineBitCastWithUser(NewVal);
    InnerRd = CombinedRd;
    InnerR = genx::makeRegionWithOffset(InnerRd, /*WantParentWidth=*/true);
    // Because the instruction scan loop does not re-process the new rdregion,
    // loop back here to re-process it.
  }
  // InnerRd and InnerR are now the combined rdregion (or the original one if
  // no combining was done).
  // Check whether we have a rdregion that is both indirect and replicating,
  // that we want to split.
  splitReplicatingIndirectRdRegion(InnerRd, &InnerR);
}
671
672 /***********************************************************************
673 * splitReplicatingIndirectRdRegion : if the rdregion is both indirect and
674 * replicating, split out the indirect part so it is read only once
675 */
splitReplicatingIndirectRdRegion(Instruction * Rd,Region * R)676 void GenXRegionCollapsing::splitReplicatingIndirectRdRegion(
677 Instruction *Rd, Region *R)
678 {
679 if (!R->Indirect)
680 return;
681 if (R->Width != R->NumElements && !R->VStride
682 && !isa<VectorType>(R->Indirect->getType())) {
683 // Replicating rows. We want an indirect region that just reads
684 // one row
685 Region IndirR = *R;
686 IndirR.NumElements = IndirR.Width;
687 auto Indir = IndirR.createRdRegion(Rd->getOperand(0),
688 Rd->getName() + ".split_replicated_indir", Rd, Rd->getDebugLoc());
689 // ... and a direct region that replicates the row.
690 R->Indirect = nullptr;
691 R->Offset = 0;
692 R->Stride = 1;
693 auto NewRd = R->createRdRegion(Indir, "", Rd, Rd->getDebugLoc());
694 NewRd->takeName(Rd);
695 Rd->replaceAllUsesWith(NewRd);
696 Rd->eraseFromParent();
697 Modified = true;
698 return;
699 }
700 if (R->Width == 1 || R->Stride)
701 return;
702 // Replicating columns. We want an indirect region that just reads
703 // one column
704 Region IndirR = *R;
705 IndirR.NumElements = IndirR.NumElements / IndirR.Width;
706 IndirR.Width = 1;
707 auto Indir = IndirR.createRdRegion(Rd->getOperand(0),
708 Rd->getName() + ".split_replicated_indir", Rd, Rd->getDebugLoc());
709 // ... and a direct region that replicates the column.
710 R->Indirect = nullptr;
711 R->Offset = 0;
712 R->VStride = 1;
713 auto NewRd = R->createRdRegion(Indir, "", Rd, Rd->getDebugLoc());
714 NewRd->takeName(Rd);
715 Rd->replaceAllUsesWith(NewRd);
716 Rd->eraseFromParent();
717 }
718
719 /***********************************************************************
720 * processWrRegionElim : process a wrregion and eliminate redundant writes
721 *
722 * This detects the following code:
723 *
724 * B = wrregion(A, V1, R)
725 * C = wrregion(B, V2, R)
726 *
727 * (where "R" is a region that is identical in the two versions
728 * this can be collapsed to
729 *
730 * D = wrregion(A, V2, R)
731 *
732 */
processWrRegionElim(Instruction * OuterWr)733 void GenXRegionCollapsing::processWrRegionElim(Instruction *OuterWr)
734 {
735 auto InnerWr = dyn_cast<Instruction>(
736 OuterWr->getOperand(GenXIntrinsic::GenXRegion::OldValueOperandNum));
737 if (!GenXIntrinsic::isWrRegion(InnerWr))
738 return;
739 // Only perform this optimisation if the only use is with outer - otherwise
740 // this seems to make the code spill more
741 IGC_ASSERT(InnerWr);
742 if (!InnerWr->hasOneUse())
743 return;
744
745 Region InnerR = genx::makeRegionFromBaleInfo(InnerWr, BaleInfo(),
746 /*WantParentWidth=*/true);
747 Region OuterR = genx::makeRegionFromBaleInfo(OuterWr, BaleInfo());
748 if (OuterR != InnerR)
749 return;
750 // Create the combined wrregion.
751 Instruction *CombinedWr = OuterR.createWrRegion(
752 InnerWr->getOperand(0),
753 OuterWr->getOperand(GenXIntrinsic::GenXRegion::NewValueOperandNum),
754 OuterWr->getName() + ".regioncollapsed", OuterWr, OuterWr->getDebugLoc());
755 OuterWr->replaceAllUsesWith(CombinedWr);
756 // Do not erase OuterWr here -- it gets erased by the caller.
757 Modified = true;
758 }
759
760 /***********************************************************************
761 * processWrRegionBitCast : handle a wrregion whose "new value" is a
762 * bitcast (before processing wrregion for region collapsing)
763 *
764 * Enter: Inst = the wrregion
765 *
766 * Return: replacement wrregion if any, else 0
767 *
768 * If the "new value" operand of the wrregion is a bitcast from scalar to
769 * 1-vector, or vice versa, then we can replace the wrregion with one that
770 * uses the input to the bitcast directly. This may enable later baling
771 * that would otherwise not happen.
772 *
773 * The bitcast typically arises from GenXLowering lowering an insertelement.
774 */
processWrRegionBitCast(Instruction * WrRegion)775 Instruction *GenXRegionCollapsing::processWrRegionBitCast(Instruction *WrRegion)
776 {
777 IGC_ASSERT(GenXIntrinsic::isWrRegion(WrRegion));
778 if (auto BC = dyn_cast<BitCastInst>(WrRegion->getOperand(
779 GenXIntrinsic::GenXRegion::NewValueOperandNum))) {
780 if (BC->getType()->getScalarType()
781 == BC->getOperand(0)->getType()->getScalarType()) {
782 // The bitcast is from scalar to 1-vector, or vice versa.
783 Region R = makeRegionFromBaleInfo(WrRegion, BaleInfo());
784 auto NewInst =
785 R.createWrRegion(WrRegion->getOperand(0), BC->getOperand(0), "",
786 WrRegion, WrRegion->getDebugLoc());
787 NewInst->takeName(WrRegion);
788 WrRegion->replaceAllUsesWith(NewInst);
789 WrRegion->eraseFromParent();
790 return NewInst;
791 }
792 }
793 return nullptr;
794 }
795
/***********************************************************************
 * processWrRegionBitCast2 : handle a wrregion whose "new value" is a
 *      bitcast (after processing wrregion for region collapsing)
 *
 * Enter:   WrRegion = the wrregion
 *
 * This does not erase WrRegion even if it becomes unused.
 *
 * If the "new value" operand of the wrregion is some other bitcast, then we
 * change the wrregion to the pre-bitcast type and add new bitcasts for the
 * "old value" input and the result. This makes it possible for the new value
 * to be baled in to the wrregion.
 */
void GenXRegionCollapsing::processWrRegionBitCast2(Instruction *WrRegion)
{
  auto BC = dyn_cast<BitCastInst>(WrRegion->getOperand(
      GenXIntrinsic::GenXRegion::NewValueOperandNum));
  if (!BC)
    return;
  // Leave predicate (i1) casts alone.
  Type *BCInputElementType = BC->getOperand(0)->getType()->getScalarType();
  if (BCInputElementType->isIntegerTy(1))
    return;

  // Do not touch writes to predefined registers.
  Value *OldValue = WrRegion->getOperand(GenXIntrinsic::GenXRegion::OldValueOperandNum);
  if (GenXIntrinsic::isReadWritePredefReg(OldValue))
    return;

  // Get the region params for the replacement wrregion, checking if that
  // fails.
  Region R = makeRegionFromBaleInfo(WrRegion, BaleInfo());
  if (!R.changeElementType(BCInputElementType, DL))
    return;
  // Bitcast the "old value" input.
  Value *OldVal = createBitCastToElementType(
      OldValue,
      BCInputElementType, WrRegion->getName() + ".precast", WrRegion, DL,
      WrRegion->getDebugLoc());
  // Create the replacement wrregion in the pre-bitcast element type.
  auto NewInst = R.createWrRegion(OldVal, BC->getOperand(0), "", WrRegion,
                                  WrRegion->getDebugLoc());
  NewInst->takeName(WrRegion);
  // Cast the result back to the original wrregion's type.
  Value *Res = createBitCast(NewInst, WrRegion->getType(),
      WrRegion->getName() + ".postcast", WrRegion, WrRegion->getDebugLoc());
  WrRegion->replaceAllUsesWith(Res);
}
843
// Check whether two values are bitwise identical.
//
// Used to prove that the value read by an outer rdregion is the same bits as
// the "old value" written back by the matching outer wrregion. Returns true
// only when identity can be proven; false means "cannot prove", not
// "different".
static bool isBitwiseIdentical(Value *V1, Value *V2, DominatorTree *DT) {
  IGC_ASSERT_MESSAGE(V1, "null value");
  IGC_ASSERT_MESSAGE(V2, "null value");
  if (V1 == V2)
    return true;
  // Look through a single bitcast on either side; a bitcast preserves bits.
  if (BitCastInst *BI = dyn_cast<BitCastInst>(V1))
    V1 = BI->getOperand(0);
  if (BitCastInst *BI = dyn_cast<BitCastInst>(V2))
    V2 = BI->getOperand(0);

  // Special case arises from vload/vstore.
  if (GenXIntrinsic::isVLoad(V1) && GenXIntrinsic::isVLoad(V2)) {
    auto L1 = cast<CallInst>(V1);
    auto L2 = cast<CallInst>(V2);

    // Loads from global variables.
    auto GV1 = getUnderlyingGlobalVariable(L1->getOperand(0));
    auto GV2 = getUnderlyingGlobalVariable(L2->getOperand(0));
    Value *Addr = L1->getOperand(0);
    if (GV1 && GV1 == GV2)
      // Both load from the same underlying global: check stores to the global.
      Addr = GV1;
    else if (L1->getOperand(0) != L2->getOperand(0))
      // Not the same global and not literally the same pointer: cannot prove
      // the two loads read the same location.
      return false;
    else if (!isa<AllocaInst>(Addr))
      // Same pointer, but not a local alloca: be conservative.
      // NOTE(review): the intent per the original comment is "pointer is
      // local and only used in vload/vstore", but only isa<AllocaInst> is
      // actually checked here — confirm hasMemoryDeps covers the rest.
      return false;

    // Identical loads unless a store to the same location can occur between
    // them.
    return !genx::hasMemoryDeps(L1, L2, Addr, DT);
  }

  // Cannot prove.
  return false;
}
881
/***********************************************************************
 * processWrRegion : process a wrregion
 *
 * Enter:   OuterWr = the wrregion instruction that we will attempt to use as
 *                    the outer wrregion and collapse with inner ones
 *
 * Return:  the replacement wrregion if any, otherwise OuterWr
 *
 * This detects the following code:
 *
 *   B = rdregion(A, OuterR)
 *   C = wrregion(B, V, InnerR)
 *   D = wrregion(A, C, OuterR)
 *
 * (where "InnerR" and "OuterR" are the region parameters). This code can
 * be collapsed to
 *
 *   D = wrregion(A, V, CombinedR)
 *
 * We want to do innermost wrregion combining first, but this pass visits
 * instructions in the wrong order for that. So, when we see a wrregion
 * here, we use recursion to scan back to find the innermost one and then work
 * forwards to where we started.
 */
Instruction *GenXRegionCollapsing::processWrRegion(Instruction *OuterWr)
{
  IGC_ASSERT(OuterWr);
  // Find the inner wrregion, skipping bitcasts.
  auto InnerWr = dyn_cast<Instruction>(
        OuterWr->getOperand(GenXIntrinsic::GenXRegion::NewValueOperandNum));
  while (InnerWr && isa<BitCastInst>(InnerWr))
    InnerWr = dyn_cast<Instruction>(InnerWr->getOperand(0));
  if (!GenXIntrinsic::isWrRegion(InnerWr))
    return OuterWr;
  // Process inner wrregions first, recursively. The recursion may replace
  // InnerWr with a collapsed wrregion.
  InnerWr = processWrRegion(InnerWr);
  // Now process this one.
  // Find the associated rdregion of the outer region, skipping bitcasts,
  // and check it has the right region parameters.
  IGC_ASSERT(InnerWr);
  auto OuterRd = dyn_cast<Instruction>(InnerWr->getOperand(0));
  while (OuterRd && isa<BitCastInst>(OuterRd))
    OuterRd = dyn_cast<Instruction>(OuterRd->getOperand(0));
  if (!GenXIntrinsic::isRdRegion(OuterRd))
    return OuterWr;
  IGC_ASSERT(OuterRd);
  // The rdregion must read (bitwise) the same value "A" that the outer
  // wrregion writes into, otherwise collapsing changes semantics.
  if (!isBitwiseIdentical(OuterRd->getOperand(0), OuterWr->getOperand(0), DT))
    return OuterWr;
  Region InnerR = genx::makeRegionWithOffset(InnerWr, /*WantParentWidth=*/true);
  Region OuterR = genx::makeRegionWithOffset(OuterWr);
  // The outer wrregion and its rdregion must use identical region parameters.
  if (OuterR != genx::makeRegionWithOffset(OuterRd))
    return OuterWr;
  // See if the regions can be combined.
  LLVM_DEBUG(dbgs() << "GenXRegionCollapsing::processWrRegion:\n"
      " OuterWr (line " << OuterWr->getDebugLoc().getLine() << "): " << *OuterWr << "\n"
      " InnerWr (line " << InnerWr->getDebugLoc().getLine() << "): " << *InnerWr << "\n");
  // Give both regions a common element type first, if possible.
  if (!normalizeElementType(&OuterR, &InnerR)) {
    LLVM_DEBUG(dbgs() << "Cannot normalize element type\n");
    return OuterWr;
  }
  Region CombinedR;
  if (!combineRegions(&OuterR, &InnerR, &CombinedR))
    return OuterWr; // cannot combine
  // Calculate index if necessary. combineRegions leaves Offset/Indirect set
  // as if the inner region were direct; calculateIndex fixes them up.
  if (InnerR.Indirect) {
    calculateIndex(&OuterR, &InnerR, &CombinedR,
        InnerWr->getOperand(GenXIntrinsic::GenXRegion::WrIndexOperandNum),
        InnerWr->getName() + ".indexcollapsed", OuterWr, InnerWr->getDebugLoc());
  }
  // Bitcast inputs if necessary, so both match the normalized element type.
  Value *OldValInput = OuterRd->getOperand(GenXIntrinsic::GenXRegion::OldValueOperandNum);
  OldValInput = createBitCastToElementType(OldValInput, InnerR.ElementTy,
      OldValInput->getName() + ".bitcast_before_collapse", OuterWr, DL, OuterWr->getDebugLoc());
  Value *NewValInput = InnerWr->getOperand(GenXIntrinsic::GenXRegion::NewValueOperandNum);
  NewValInput = createBitCastToElementType(NewValInput, InnerR.ElementTy,
      NewValInput->getName() + ".bitcast_before_collapse", OuterWr, DL, OuterWr->getDebugLoc());
  // Create the combined wrregion.
  Instruction *CombinedWr = CombinedR.createWrRegion(
      OldValInput, NewValInput, InnerWr->getName() + ".regioncollapsed",
      OuterWr, InnerWr->getDebugLoc());
  // Bitcast to the original type if necessary.
  Value *Res = createBitCast(CombinedWr, OuterWr->getType(),
                             CombinedWr->getName() + ".cast", OuterWr,
                             InnerWr->getDebugLoc());
  // Replace all uses.
  OuterWr->replaceAllUsesWith(Res);
  // Do not erase OuterWr here, as (if this function recursed to process an
  // inner wrregion first) OuterWr might be the same as Prev in the loop in
  // runOnFunction(). For a recursive call of processWrRegion, it will
  // eventually get visited and then erased as it has no uses. For an outer
  // call of processWrRegion, OuterWr is erased by the caller.
  Modified = true;
  return CombinedWr;
}
976
/***********************************************************************
 * processWrRegionSplat : process a wrregion whose inputs are undef or a
 *      common splat constant
 *
 * Enter:   OuterWr = the wrregion instruction that we will attempt to use as
 *                    the outer wrregion and collapse with inner ones
 *
 * Return:  the replacement wrregion if any, otherwise OuterWr
 *
 * This detects the following code:
 *
 *   C = wrregion(undef, V, InnerR)
 *   D = wrregion(undef, C, OuterR)
 *
 * (where "InnerR" and "OuterR" are the region parameters). This code can
 * be collapsed to
 *
 *   D = wrregion(undef, V, CombinedR)
 *
 * We want to do innermost wrregion combining first, but this pass visits
 * instructions in the wrong order for that. So, when we see a wrregion
 * here, we use recursion to scan back to find the innermost one and then work
 * forwards to where we started.
 */
Instruction *GenXRegionCollapsing::processWrRegionSplat(Instruction *OuterWr)
{
  IGC_ASSERT(OuterWr);
  // Find the inner wrregion, skipping bitcasts.
  auto InnerWr = dyn_cast<Instruction>(
        OuterWr->getOperand(GenXIntrinsic::GenXRegion::NewValueOperandNum));
  while (InnerWr && isa<BitCastInst>(InnerWr))
    InnerWr = dyn_cast<Instruction>(InnerWr->getOperand(0));
  if (!GenXIntrinsic::isWrRegion(InnerWr))
    return OuterWr;
  // Process inner wrregions first, recursively.
  InnerWr = processWrRegionSplat(InnerWr);

  // Now process this one.
  // The inner "old value" must be a constant for this transformation.
  auto InnerSrc = dyn_cast<Constant>(InnerWr->getOperand(GenXIntrinsic::GenXRegion::OldValueOperandNum));
  if (!InnerSrc)
    return OuterWr;
  // Ensure that the combined region is well-defined.
  if (InnerSrc->getType()->getScalarSizeInBits() !=
      OuterWr->getType()->getScalarSizeInBits())
    return OuterWr;

  // The outer "old value" must also be a constant.
  auto OuterSrc = dyn_cast<Constant>(OuterWr->getOperand(GenXIntrinsic::GenXRegion::OldValueOperandNum));
  if (!OuterSrc)
    return OuterWr;
  if (isa<UndefValue>(InnerSrc)) {
    // OK: writing into undef, the outer old value survives unchanged.
  } else {
    // Otherwise both old values must be splats of the same scalar, so the
    // elements not covered by the combined region still hold that value.
    auto InnerSplat = InnerSrc->getSplatValue();
    auto OuterSplat = OuterSrc->getSplatValue();
    if (!InnerSplat || !OuterSplat || InnerSplat != OuterSplat)
      return OuterWr;
  }

  Region InnerR = genx::makeRegionWithOffset(InnerWr, /*WantParentWidth=*/true);
  Region OuterR = genx::makeRegionWithOffset(OuterWr);
  Region CombinedR;
  if (!combineRegions(&OuterR, &InnerR, &CombinedR))
    return OuterWr; // cannot combine
  // Calculate index if necessary. combineRegions leaves Offset/Indirect set
  // as if the inner region were direct; calculateIndex fixes them up.
  if (InnerR.Indirect) {
    calculateIndex(&OuterR, &InnerR, &CombinedR,
        InnerWr->getOperand(GenXIntrinsic::GenXRegion::WrIndexOperandNum),
        InnerWr->getName() + ".indexcollapsed", OuterWr, InnerWr->getDebugLoc());
  }
  // Bitcast inputs if necessary.
  Value *OldValInput = OuterSrc;
  Value *NewValInput = InnerWr->getOperand(1);
  NewValInput = createBitCastToElementType(NewValInput, OuterWr->getType()->getScalarType(),
      NewValInput->getName() + ".bitcast_before_collapse", OuterWr, DL, OuterWr->getDebugLoc());
  // Create the combined wrregion.
  Instruction *CombinedWr = CombinedR.createWrRegion(
      OldValInput, NewValInput, InnerWr->getName() + ".regioncollapsed",
      OuterWr, InnerWr->getDebugLoc());
  // Bitcast to the original type if necessary.
  Value *Res = createBitCast(CombinedWr, OuterWr->getType(),
                             CombinedWr->getName() + ".cast", OuterWr,
                             InnerWr->getDebugLoc());
  // Replace all uses.
  OuterWr->replaceAllUsesWith(Res);
  // Do not erase OuterWr here, as (if this function recursed to process an
  // inner wrregion first) OuterWr might be the same as Prev in the loop in
  // runOnFunction(). For a recursive call of processWrRegionSplat, it will
  // eventually get visited and then erased as it has no uses. For an outer
  // call of processWrRegionSplat, OuterWr is erased by the caller.
  Modified = true;
  return CombinedWr;
}
1068
1069 /***********************************************************************
1070 * normalizeElementType : where two regions have different element size,
1071 * make them the same if possible
1072 *
1073 * Enter: R1 = first region
1074 * R2 = second region
1075 * PreferFirst = true to prefer the first region's element type
1076 *
1077 * Return: false if failed
1078 *
1079 * If PreferFirst is false, this uses the larger element size if everything is
1080 * suitably aligned and the region with the smaller element size can be
1081 * converted to the larger element size.
1082 *
1083 * Otherwise, it uses the smaller element size if the region with the
1084 * larger element size can be converted to the smaller element size.
1085 */
normalizeElementType(Region * R1,Region * R2,bool PreferFirst)1086 bool GenXRegionCollapsing::normalizeElementType(Region *R1, Region *R2,
1087 bool PreferFirst)
1088 {
1089 if (R1->ElementBytes == R2->ElementBytes)
1090 return true; // nothing to do
1091 LLVM_DEBUG(dbgs() << "Before normalizeElementType:\n"
1092 " R1: " << *R1 << "\n"
1093 " R2: " << *R2 << "\n");
1094 // Set BigR to the region with the bigger element size, and SmallR to the
1095 // region with the smaller element size.
1096 bool PreferSmall = false;
1097 Region *BigR = nullptr, *SmallR = nullptr;
1098 if (R1->ElementBytes > R2->ElementBytes) {
1099 BigR = R1;
1100 SmallR = R2;
1101 } else {
1102 BigR = R2;
1103 SmallR = R1;
1104 PreferSmall = PreferFirst;
1105 }
1106 // Try the smaller element size first if it is preferred by the caller.
1107 if (PreferSmall)
1108 if (!BigR->Indirect) // big region not indirect
1109 if (BigR->changeElementType(SmallR->ElementTy, DL))
1110 return true;
1111 // Then try the bigger element size.
1112 if (!SmallR->Indirect) // small region not indirect
1113 if (SmallR->changeElementType(BigR->ElementTy, DL))
1114 return true;
1115 // Then try the smaller element size.
1116 if (!PreferSmall)
1117 if (!BigR->Indirect) // big region not indirect
1118 if (BigR->changeElementType(SmallR->ElementTy, DL))
1119 return true;
1120 return false;
1121 }
1122
/***********************************************************************
 * combineRegions : combine two regions if possible
 *
 * Enter:   OuterR = Region struct for outer region
 *          InnerR = Region struct for inner region
 *          CombinedR = Region struct to write combined region into
 *
 * Return:  true if combining is possible
 *
 * If combining is possible, this function sets up CombinedR. However,
 * CombinedR->Offset and CombinedR->Indirect are set assuming that the
 * inner region is direct; for an indirect inner region the caller must
 * fix them up with calculateIndex().
 *
 * If OuterR->ElementTy != InnerR->ElementTy, this algo cannot determine
 * CombinedR->ElementTy, as the type depends on the order of respective
 * wr/rd regions (it should be the type of the last one).
 */
bool GenXRegionCollapsing::combineRegions(const Region *OuterR,
    const Region *InnerR, Region *CombinedR)
{
  LLVM_DEBUG(dbgs() << "GenXRegionCollapsing::combineRegions\n"
      " OuterR: " << *OuterR << "\n"
      " InnerR: " << *InnerR << "\n");
  if (InnerR->Indirect && isa<VectorType>(InnerR->Indirect->getType()))
    return false; // multi indirect not supported
  if (OuterR->Indirect && isa<VectorType>(OuterR->Indirect->getType()))
    return false; // multi indirect not supported
  if (OuterR->Mask)
    return false; // outer region predicated, cannot combine
  // Start from the inner region; strides compose multiplicatively with the
  // outer (1D) stride, measured in outer elements.
  *CombinedR = *InnerR;
  CombinedR->Indirect = OuterR->Indirect;
  CombinedR->Stride *= OuterR->Stride;
  CombinedR->VStride *= OuterR->Stride;
  // Inner offset expressed in elements.
  unsigned ElOffset = InnerR->Offset / InnerR->ElementBytes;
  if (OuterR->is2D()) {
    // Outer region is 2D: create the combined offset. For outer 2D
    // and inner indirect, what CombinedR->Offset is set to here is
    // ignored and overwritten by calculateIndex(), so it does not matter
    // that it is incorrect in that case.
    ElOffset = ElOffset / OuterR->Width * OuterR->VStride
        + ElOffset % OuterR->Width * OuterR->Stride;
  } else {
    // Outer region is 1D: create the combined offset. For the benefit
    // of inner indirect, where InnerR->Offset is just an offset from
    // InnerR->Indirect, we cope with InnerR->Offset being apparently
    // out of range (negative or too big).
    ElOffset *= OuterR->Stride;
  }
  CombinedR->Offset = OuterR->Offset + ElOffset * InnerR->ElementBytes;
  if (!OuterR->is2D()) {
    LLVM_DEBUG(dbgs() << "outer 1D: CombinedR: " << *CombinedR << "\n");
    return true; // outer region is 1D, can always combine
  }
  if (InnerR->isScalar()) {
    LLVM_DEBUG(dbgs() << "inner scalar/splat: CombinedR: " << *CombinedR << "\n");
    return true; // inner region is scalar/splat, can always combine
  }
  if (InnerR->Indirect) {
    // Indirect inner region. Can combine as long as inner vstride is a
    // multiple of outer width, and it in turn is a multiple of inner parent
    // width.
    if (InnerR->ParentWidth && !(InnerR->VStride % (int)OuterR->Width)
        && !(OuterR->Width % InnerR->ParentWidth)) {
      CombinedR->VStride = InnerR->VStride / OuterR->Width * OuterR->VStride;
      LLVM_DEBUG(dbgs() << "inner indirect: CombinedR: " << *CombinedR << "\n");
      return true;
    }
    LLVM_DEBUG(dbgs() << "inner indirect: failed\n");
    return false;
  }
  // Inner region is not indirect (the indirect case returned above).
  unsigned StartEl = InnerR->Offset / InnerR->ElementBytes;
  unsigned StartRow = StartEl / OuterR->Width;
  if (!InnerR->is2D()) {
    // Inner region is 1D but outer region is 2D.
    unsigned EndEl = StartEl + (InnerR->NumElements - 1) * InnerR->Stride;
    unsigned EndRow = EndEl / OuterR->Width;
    if (StartRow == EndRow) {
      // The whole 1D inner region fits in a row of the outer region.
      LLVM_DEBUG(dbgs() << "inner 1D outer 2D, fits in row: CombinedR: " << *CombinedR << "\n");
      return true;
    }
    if (EndRow == StartRow + 1 && !(InnerR->NumElements % 2)) {
      unsigned MidEl = StartEl + InnerR->NumElements / 2 * InnerR->Stride;
      if (InnerR->Stride > 0 && (unsigned)(MidEl - (EndRow * OuterR->Width))
          < (unsigned)InnerR->Stride) {
        // The 1D inner region is evenly split between two adjacent rows of
        // the outer region.
        CombinedR->VStride = (MidEl % OuterR->Width - StartEl % OuterR->Width)
            * OuterR->Stride + OuterR->VStride;
        CombinedR->Width = InnerR->NumElements / 2;
        LLVM_DEBUG(dbgs() << "inner 1D outer 2D, split between two rows: CombinedR: " << *CombinedR << "\n");
        return true;
      }
    }
    unsigned BeyondEndEl = EndEl + InnerR->Stride;
    if (BeyondEndEl % OuterR->Width == StartEl % OuterR->Width
        && !(OuterR->Width % InnerR->Stride)) {
      // The 1D inner region is evenly split between N adjacent rows of the
      // outer region, starting in the same column for each row.
      CombinedR->Width = OuterR->Width / InnerR->Stride;
      CombinedR->VStride = OuterR->VStride;
      LLVM_DEBUG(dbgs() << "inner 1D outer 2D, split between N rows: CombinedR: " << *CombinedR << "\n");
      return true;
    }
    LLVM_DEBUG(dbgs() << "inner 1D outer 2D, fail\n");
    return false; // All other 1D inner region cases fail.
  }
  // Inner and outer are both 2D.
  if (!(InnerR->VStride % (int)OuterR->Width)) {
    // Inner vstride is a whole number of outer rows.
    CombinedR->VStride = OuterR->VStride * InnerR->VStride / (int)OuterR->Width;
    if (!InnerR->Indirect) {
      // For a direct inner region, calculate whether we can combine.
      // NOTE(review): these locals shadow StartEl/StartRow above with the
      // same values; and InnerR->Indirect is always null at this point, so
      // the else branch below appears to be dead code — confirm.
      unsigned StartEl = InnerR->Offset / InnerR->ElementBytes;
      unsigned StartRow = StartEl / OuterR->Width;
      unsigned EndRowOfFirstRow = (StartEl + (InnerR->Width - 1) * InnerR->Stride)
          / OuterR->Width;
      if (StartRow == EndRowOfFirstRow) {
        // Each row of inner region is within a row of outer region, starting
        // at the same column.
        LLVM_DEBUG(dbgs() << "row within row: CombinedR: " << *CombinedR << "\n");
        return true;
      }
    } else {
      // For an indirect inner region, use parent width to tell whether we can
      // combine.
      if (InnerR->ParentWidth && !(OuterR->Width % InnerR->ParentWidth)) {
        LLVM_DEBUG(dbgs() << "inner indirect, parentwidth ok: CombinedR: " << *CombinedR << "\n");
        return true;
      }
    }
  }
  // We could handle other cases like:
  // - each row of inner region enclosed in a row of outer region
  //   but with a different column offset
  LLVM_DEBUG(dbgs() << "failed\n");
  return false;
}
1261
/***********************************************************************
 * calculateIndex : calculate index in the case that the inner region is
 *      indirect
 *
 * Enter:   OuterR, InnerR = outer and inner regions
 *          CombinedR = combined region set up by combineRegions()
 *          InnerIndex = variable index for inner region, including the
 *                       constant offset add that was extracted by the Region
 *                       constructor into InnerR->Offset
 *          Name = name for new instruction(s)
 *          InsertBefore = insert before this instruction
 *          DL = debug loc for new instruction(s)
 *
 * This sets up CombinedR->Indirect and CombinedR->Offset.
 *
 * A Region has the offset set up as follows:
 *
 * - For a direct region, R.Offset is the constant offset in bytes and
 *   R.Indirect is 0.
 *
 * - Normally, for an indirect region, R.Offset is 0 and R.Indirect is the
 *   Value used for the offset (in bytes).
 *
 * - But if the Value used for the offset is an add constant, then R.Offset
 *   is the constant offset and R.Indirect is the other operand of the add.
 *
 * In some code paths, this function needs the actual index of the inner region,
 * rather than the R.Offset and R.Indirect parts separated out by the Region
 * constructor. Thus it is passed InnerIndex, which is that actual index value.
 */
void GenXRegionCollapsing::calculateIndex(const Region *OuterR,
    const Region *InnerR, Region *CombinedR, Value *InnerIndex,
    const Twine &Name, Instruction *InsertBefore, const DebugLoc &DL)
{
  if (!OuterR->is2D()) {
    // Outer region is 1D. We can leave CombinedR->Offset as
    // set by combineRegions, but we need to add the indices together, scaling
    // the inner one by the outer region's stride.
    Value *Idx = InnerR->Indirect;
    // Scale the inner index by the outer stride (skip the no-op multiply).
    if (OuterR->Stride != 1) {
      Idx = insertOp(Instruction::Mul, Idx, OuterR->Stride, Name,
          InsertBefore, DL);
      LLVM_DEBUG(dbgs() << " calculateIndex: " << *Idx << "\n");
    }
    // If the outer region is itself indirect, add its index in too.
    if (OuterR->Indirect) {
      Idx = insertOp(Instruction::Add, Idx, OuterR->Indirect, Name,
          InsertBefore, DL);
      LLVM_DEBUG(dbgs() << " calculateIndex: " << *Idx << "\n");
    }
    CombinedR->Indirect = Idx;
    LLVM_DEBUG(dbgs() << " calculateIndex result(1d): CombinedR: " << *CombinedR << "\n");
    return;
  }
  // Outer region is 2D. We need to split the inner region's index into row
  // and column of the outer region, then recombine. We are using InnerIndex,
  // which includes any constant offset add, so we need to adjust
  // CombinedR->Offset so it does not include InnerR->Offset.
  CombinedR->Offset = OuterR->Offset;
  LLVM_DEBUG(dbgs() << " calculateIndex: Offset now " << CombinedR->Offset << "\n");
  // Column = index % (row length in bytes); Row = index / (row length in
  // bytes). insertOp strength-reduces these when the divisor is a power of 2.
  Value *Col = insertOp(Instruction::URem, InnerIndex,
      OuterR->Width * OuterR->ElementBytes,
      Name, InsertBefore, DL);
  LLVM_DEBUG(dbgs() << " calculateIndex: " << *Col << "\n");
  Value *Row = insertOp(Instruction::UDiv, InnerIndex,
      OuterR->Width * OuterR->ElementBytes,
      Name, InsertBefore, DL);
  LLVM_DEBUG(dbgs() << " calculateIndex: " << *Row << "\n");
  Value *Idx = nullptr;
  if (!(OuterR->VStride % OuterR->Stride)) {
    // We need to multply Row by VStride and Col by Stride. However, Stride
    // divides VStride evenly, so we can common up the multiply by Stride:
    //   Idx = (Row * VStride/Stride + Col) * Stride
    Idx = insertOp(Instruction::Mul, Row,
        OuterR->VStride * OuterR->ElementBytes / OuterR->Stride,
        Name, InsertBefore, DL);
    LLVM_DEBUG(dbgs() << " calculateIndex: " << *Idx << "\n");
    Idx = insertOp(Instruction::Add, Idx, Col, Name, InsertBefore, DL);
    LLVM_DEBUG(dbgs() << " calculateIndex: " << *Idx << "\n");
    Idx = insertOp(Instruction::Mul, Idx, OuterR->Stride, Name, InsertBefore, DL);
    LLVM_DEBUG(dbgs() << " calculateIndex: " << *Idx << "\n");
  } else {
    // Need to do Row*VStride and Col*Stride separately.
    Idx = insertOp(Instruction::Mul, Row,
        OuterR->VStride * OuterR->ElementBytes, Name, InsertBefore, DL);
    LLVM_DEBUG(dbgs() << " calculateIndex: " << *Idx << "\n");
    Col = insertOp(Instruction::Mul, Col, OuterR->Stride, Name, InsertBefore, DL);
    LLVM_DEBUG(dbgs() << " calculateIndex: " << *Col << "\n");
    Idx = insertOp(Instruction::Add, Idx, Col, Name, InsertBefore, DL);
    LLVM_DEBUG(dbgs() << " calculateIndex: " << *Idx << "\n");
  }
  // If the outer region is itself indirect, add its index in too.
  if (OuterR->Indirect) {
    Idx = insertOp(Instruction::Add, Idx, OuterR->Indirect,
        Name, InsertBefore, DL);
    LLVM_DEBUG(dbgs() << " calculateIndex: " << *Idx << "\n");
  }
  CombinedR->Indirect = Idx;
  LLVM_DEBUG(dbgs() << " calculateIndex result(2d): CombinedR: " << *CombinedR << "\n");
}
1359
1360 /***********************************************************************
1361 * insertOp : insert a binary op
1362 */
insertOp(Instruction::BinaryOps Opcode,Value * Lhs,unsigned Rhs,const Twine & Name,Instruction * InsertBefore,const DebugLoc & DL)1363 Value *GenXRegionCollapsing::insertOp(Instruction::BinaryOps Opcode, Value *Lhs,
1364 unsigned Rhs, const Twine &Name, Instruction *InsertBefore, const DebugLoc &DL)
1365 {
1366 auto I16Ty = Type::getInt16Ty(InsertBefore->getContext());
1367 return insertOp(Opcode, Lhs,
1368 Constant::getIntegerValue(I16Ty, APInt(16, Rhs)),
1369 Name, InsertBefore, DL);
1370 }
1371
insertOp(Instruction::BinaryOps Opcode,Value * Lhs,Value * Rhs,const Twine & Name,Instruction * InsertBefore,const DebugLoc & DL)1372 Value *GenXRegionCollapsing::insertOp(Instruction::BinaryOps Opcode, Value *Lhs,
1373 Value *Rhs, const Twine &Name, Instruction *InsertBefore, const DebugLoc &DL)
1374 {
1375 if (auto C = dyn_cast<ConstantInt>(Rhs)) {
1376 int RhsVal = C->getZExtValue();
1377 int LogVal = genx::exactLog2(RhsVal);
1378 if (LogVal >= 0) {
1379 switch (Opcode) {
1380 case Instruction::Mul:
1381 // multiply by power of 2 -> shl
1382 if (!LogVal)
1383 return Lhs;
1384 Rhs = Constant::getIntegerValue(C->getType(), APInt(16, LogVal));
1385 Opcode = Instruction::Shl;
1386 break;
1387 case Instruction::UDiv:
1388 // divide by power of 2 -> lshr
1389 if (!LogVal)
1390 return Lhs;
1391 Rhs = Constant::getIntegerValue(C->getType(), APInt(16, LogVal));
1392 Opcode = Instruction::LShr;
1393 break;
1394 case Instruction::URem:
1395 // remainder by power of 2 -> and
1396 Rhs = Constant::getIntegerValue(C->getType(), APInt(16, RhsVal - 1));
1397 Opcode = Instruction::And;
1398 break;
1399 default:
1400 break;
1401 }
1402 }
1403 }
1404 auto Inst = BinaryOperator::Create(Opcode, Lhs, Rhs, Name, InsertBefore);
1405 Inst->setDebugLoc(DL);
1406 return Inst;
1407 }
1408
isSingleElementRdRExtract(Instruction * I)1409 bool GenXRegionCollapsing::isSingleElementRdRExtract(Instruction *I) {
1410 if (!GenXIntrinsic::isRdRegion(I))
1411 return false;
1412 Region R = genx::makeRegionWithOffset(I, /*WantParentWidth=*/true);
1413 return R.NumElements == 1 && !R.Indirect;
1414 }
1415