1 /*========================== begin_copyright_notice ============================
2 
3 Copyright (C) 2017-2021 Intel Corporation
4 
5 SPDX-License-Identifier: MIT
6 
7 ============================= end_copyright_notice ===========================*/
8 
9 //
10 /// GenXRegionCollapsing
11 /// --------------------
12 ///
13 /// The GenX region collapsing pass is a function pass that collapses nested
14 /// read regions or nested write regions.
15 ///
16 /// Nested region accesses can occur in two ways (or a mixture of both):
17 ///
18 /// 1. The front end compiler deliberately generates nested region access. The
19 ///    CM compiler does this for a matrix select, generating a region access for
20 ///    the rows and another one for the columns, safe in the knowledge that this
21 ///    pass will combine them where it can.
22 ///
23 /// 2. Two region accesses in different source code constructs (e.g. two select()
24 ///    calls, either in the same or different source statements).
25 ///
26 /// The combineRegions() function is what makes the decisions on whether two
27 /// regions can be collapsed, depending on whether they are 1D or 2D, how the
28 /// rows of one fit in the rows of the other, whether each is indirect, etc.
29 ///
30 /// This pass makes an effort to combine two region accesses even if there are
31 /// multiple bitcasts (from CM format()) or up to one SExt/ZExt (from a cast) in
32 /// between.
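///
/// For example (illustrative only, intrinsic operands abbreviated), a matrix
/// select might first read whole rows and then the wanted columns:
///
///   %rows = call <16 x i32> @llvm.genx.rdregioni...(<32 x i32> %m, ...)
///   %sel  = call <8 x i32> @llvm.genx.rdregioni...(<16 x i32> %rows, ...)
///
/// Where the region parameters allow it, this pass collapses the pair into a
/// single rdregion that reads the selected elements directly from %m.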
33 ///
34 //===----------------------------------------------------------------------===//
35 #define DEBUG_TYPE "GENX_RegionCollapsing"
36 
37 #include "GenX.h"
38 #include "GenXBaling.h"
39 #include "GenXUtil.h"
40 
41 #include "llvm/ADT/PostOrderIterator.h"
42 #include "llvm/Analysis/CFG.h"
43 #include "llvm/Analysis/InstructionSimplify.h"
44 #include "llvm/IR/BasicBlock.h"
45 #include "llvm/IR/Constants.h"
46 #include "llvm/IR/DataLayout.h"
47 #include "llvm/IR/Dominators.h"
48 #include "llvm/IR/Function.h"
49 #include "llvm/IR/Instructions.h"
50 #include "llvm/IR/Intrinsics.h"
51 #include "llvm/InitializePasses.h"
52 #include "llvm/Support/Debug.h"
53 #include "llvm/Transforms/Utils/Local.h"
54 #include "Probe/Assertion.h"
55 
56 #include "llvmWrapper/IR/DerivedTypes.h"
57 
58 using namespace llvm;
59 using namespace genx;
60 
61 namespace {
62 
63 // GenX region collapsing pass
64 class GenXRegionCollapsing : public FunctionPass {
65   const DataLayout *DL = nullptr;
66   DominatorTree *DT = nullptr;
67   bool Modified = false;
68 public:
69   static char ID;
70   explicit GenXRegionCollapsing() : FunctionPass(ID) { }
71   StringRef getPassName() const override { return "GenX Region Collapsing"; }
72   void getAnalysisUsage(AnalysisUsage &AU) const override {
73     AU.addRequired<DominatorTreeWrapperPass>();
74     AU.setPreservesCFG();
75   }
76   bool runOnFunction(Function &F) override;
77 
78 private:
79   void runOnBasicBlock(BasicBlock *BB);
80   void processBitCast(BitCastInst *BC);
81   void processRdRegion(Instruction *InnerRd);
82   void splitReplicatingIndirectRdRegion(Instruction *Rd, Region *R);
83   void processWrRegionElim(Instruction *OuterWr);
84   Instruction *processWrRegionBitCast(Instruction *WrRegion);
85   void processWrRegionBitCast2(Instruction *WrRegion);
86   Instruction *processWrRegion(Instruction *OuterWr);
87   Instruction *processWrRegionSplat(Instruction *OuterWr);
88   bool normalizeElementType(Region *R1, Region *R2, bool PreferFirst = false);
89   bool combineRegions(const Region *OuterR, const Region *InnerR,
90                       Region *CombinedR);
91   void calculateIndex(const Region *OuterR, const Region *InnerR,
92                       Region *CombinedR, Value *InnerIndex, const Twine &Name,
93                       Instruction *InsertBefore, const DebugLoc &DL);
94   Value *insertOp(Instruction::BinaryOps Opcode, Value *Lhs, unsigned Rhs,
95                   const Twine &Name, Instruction *InsertBefore,
96                   const DebugLoc &DL);
97   Value *insertOp(Instruction::BinaryOps Opcode, Value *Lhs, Value *Rhs,
98                   const Twine &Name, Instruction *InsertBefore,
99                   const DebugLoc &DL);
100   bool isSingleElementRdRExtract(Instruction *I);
101 };
102 
103 } // end anonymous namespace
104 
105 
106 char GenXRegionCollapsing::ID = 0;
107 namespace llvm { void initializeGenXRegionCollapsingPass(PassRegistry &); }
108 INITIALIZE_PASS_BEGIN(GenXRegionCollapsing, "GenXRegionCollapsing",
109                       "GenXRegionCollapsing", false, false)
110 INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
111 INITIALIZE_PASS_END(GenXRegionCollapsing, "GenXRegionCollapsing",
112                     "GenXRegionCollapsing", false, false)
113 
114 // Publicly exposed interface to pass...
115 FunctionPass *llvm::createGenXRegionCollapsingPass()
116 {
117   initializeGenXRegionCollapsingPass(*PassRegistry::getPassRegistry());
118   return new GenXRegionCollapsing();
119 }
120 
121 /***********************************************************************
122  * runOnFunction : run the region collapsing pass for this Function
123  */
124 bool GenXRegionCollapsing::runOnFunction(Function &F)
125 {
126   DL = &F.getParent()->getDataLayout();
127   DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();
128 
129   // Track if there is any modification to the function.
130   bool Changed = false;
131 
132   // This does a postordered depth first traversal of the CFG, processing
133   // instructions within a basic block in reverse, to ensure that we see a def
134   // after its uses (ignoring phi node uses).
135   for (po_iterator<BasicBlock *> i = po_begin(&F.getEntryBlock()),
136                                  e = po_end(&F.getEntryBlock());
137        i != e; ++i) {
138     // Iterate until there is no modification.
139     BasicBlock *BB = *i;
140     do {
141       Modified = false;
142       runOnBasicBlock(BB);
143       if (Modified)
144         Changed = true;
145     } while (Modified);
146   }
147 
148   return Changed;
149 }
150 
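/***********************************************************************
 * lowerTrunc : turn a trunc whose input is a rdregion into a bitcast
 *    followed by a strided rdregion, so the result can take part in
 *    region collapsing later on
 *
 * Illustrative example (types chosen for exposition only):
 *   %t = trunc <4 x i32> %rd to <4 x i16>
 * becomes
 *   %bc = bitcast <4 x i32> %rd to <8 x i16>
 *   %t = rdregion of %bc with NumElements = 4, Width = 4, Stride = 2,
 *        VStride = 8
 */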
151 static bool lowerTrunc(TruncInst *Inst) {
152   Value *InValue = Inst->getOperand(0);
153   if (!GenXIntrinsic::isRdRegion(InValue))
154     return false;
155 
156   Type *InElementTy = InValue->getType();
157   Type *OutElementTy = Inst->getType();
158   unsigned NumElements = 1;
159   if (auto *VT = dyn_cast<IGCLLVM::FixedVectorType>(InElementTy)) {
160     InElementTy = VT->getElementType();
161     OutElementTy = cast<VectorType>(OutElementTy)->getElementType();
162     NumElements = VT->getNumElements();
163   }
164   unsigned OutBitSize = OutElementTy->getPrimitiveSizeInBits();
165   IGC_ASSERT(OutBitSize);
166   // Do not touch truncations to i1 or vector of i1 types.
167   if (OutBitSize == 1)
168     return false;
169   unsigned Stride = InElementTy->getPrimitiveSizeInBits() / OutBitSize;
170 
171   // Create the new bitcast.
172   Instruction *BC = CastInst::Create(
173       Instruction::BitCast, InValue,
174       IGCLLVM::FixedVectorType::get(OutElementTy, Stride * NumElements),
175       Inst->getName(), Inst /*InsertBefore*/);
176   BC->setDebugLoc(Inst->getDebugLoc());
177 
178   // Create the new rdregion.
179   Region R(BC);
180   R.NumElements = NumElements;
181   R.Stride = Stride;
182   R.Width = NumElements;
183   R.VStride = R.Stride * R.Width;
184   Instruction *NewInst = R.createRdRegion(
185       BC, Inst->getName(), Inst /*InsertBefore*/, Inst->getDebugLoc(),
186       !isa<VectorType>(Inst->getType()) /*AllowScalar*/);
187 
188   // Change uses and mark the old inst for erasing.
189   Inst->replaceAllUsesWith(NewInst);
190   return true;
191 }
192 
193 void GenXRegionCollapsing::runOnBasicBlock(BasicBlock *BB) {
194   // Code simplification in block first.
195   for (auto BI = BB->begin(), E = --BB->end(); BI != E;) {
196     IGC_ASSERT(!BI->isTerminator());
197     Instruction *Inst = &*BI++;
198     if (Inst->use_empty())
199       continue;
200 
201     // Turn trunc into bitcast followed by rdr. This helps region collapsing in
202     // a later stage.
203     if (auto TI = dyn_cast<TruncInst>(Inst)) {
204       Modified |= lowerTrunc(TI);
205       continue;
206     }
207 
208     // Simplify
209     // %1 = call <1 x i32> @rdr(...)
210     // %2 = extractelement <1 x i32> %1, i32 0
211     // into
212     // %2 = call i32 @rdr(...)
213     //
214     if (auto EEI = dyn_cast<ExtractElementInst>(Inst)) {
215       Value *Src = EEI->getVectorOperand();
216       if (GenXIntrinsic::isRdRegion(Src) &&
217           cast<IGCLLVM::FixedVectorType>(Src->getType())->getNumElements() ==
218               1) {
219         // Create a new region with scalar output.
220         Region R(Inst);
221         Instruction *NewInst =
222             R.createRdRegion(Src, Inst->getName(), Inst /*InsertBefore*/,
223                              Inst->getDebugLoc(), true /*AllowScalar*/);
224         Inst->replaceAllUsesWith(NewInst);
225         Modified = true;
226         continue;
227       }
228     }
229 
230     if (Value *V = simplifyRegionInst(Inst, DL)) {
231       Inst->replaceAllUsesWith(V);
232       Modified = true;
233       continue;
234     }
235 
236     // Sink index calculation before region collapsing. For collapsed regions,
237     // it is more difficult to lift constant offsets.
238     static const unsigned NOT_INDEX = 255;
239     unsigned Index = NOT_INDEX;
240 
241     unsigned IID = GenXIntrinsic::getGenXIntrinsicID(Inst);
242     if (GenXIntrinsic::isRdRegion(IID))
243       Index = GenXIntrinsic::GenXRegion::RdIndexOperandNum;
244     else if (GenXIntrinsic::isWrRegion(IID))
245       Index = GenXIntrinsic::GenXRegion::WrIndexOperandNum;
246     else if (isa<InsertElementInst>(Inst))
247       Index = 2;
248     else if (isa<ExtractElementInst>(Inst))
249       Index = 1;
250 
251     if (Index != NOT_INDEX) {
252       Use *U = &Inst->getOperandUse(Index);
253       Value *V = sinkAdd(*U);
254       if (V != U->get()) {
255         *U = V;
256         Modified = true;
257       }
258     }
259   }
260   Modified |= SimplifyInstructionsInBlock(BB);
261 
262   // This loop processes instructions in reverse, tolerating an instruction
263   // being removed during its processing, and not re-processing any new
264   // instructions added during the processing of an instruction.
265   for (Instruction *Prev = BB->getTerminator(); Prev;) {
266     Instruction *Inst = Prev;
267     Prev = nullptr;
268     if (Inst != &BB->front())
269       Prev = Inst->getPrevNode();
270     switch (GenXIntrinsic::getGenXIntrinsicID(Inst)) {
271     case GenXIntrinsic::genx_rdregioni:
272     case GenXIntrinsic::genx_rdregionf:
273       processRdRegion(Inst);
274       break;
275     case GenXIntrinsic::genx_wrregioni:
276     case GenXIntrinsic::genx_wrregionf:
277       processWrRegionElim(Inst);
278       if (!Inst->use_empty()) {
279         if (auto NewInst = processWrRegionBitCast(Inst)) {
280           Modified = true;
281           Inst = NewInst;
282         }
283         auto NewInst1 = processWrRegionSplat(Inst);
284         if (Inst != NewInst1) {
285           Modified = true;
286           Inst = NewInst1;
287         }
288 
289         auto NewInst = processWrRegion(Inst);
290         processWrRegionBitCast2(NewInst);
291         if (Inst != NewInst && NewInst->use_empty()) {
292           NewInst->eraseFromParent();
293           Modified = true;
294         }
295       }
296       if (Inst->use_empty()) {
297         Inst->eraseFromParent();
298         Modified = true;
299       }
300       break;
301     default:
302       if (auto BC = dyn_cast<BitCastInst>(Inst))
303         processBitCast(BC);
304       if (isa<CastInst>(Inst) && Inst->use_empty()) {
305         // Remove a cast that has become unused due to changes in this pass.
306         Inst->eraseFromParent();
307         Modified = true;
308       }
309       break;
310     }
311   }
312 }
313 
314 /***********************************************************************
315  * createBitCast : create a bitcast, combining bitcasts where applicable
316  */
317 static Value *createBitCast(Value *Input, Type *Ty, const Twine &Name,
318                             Instruction *InsertBefore, const DebugLoc &DL) {
319   if (Input->getType() == Ty)
320     return Input;
321   if (auto BC = dyn_cast<BitCastInst>(Input))
322     Input = BC->getOperand(0);
323   if (Input->getType() == Ty)
324     return Input;
325   auto NewBC = CastInst::Create(Instruction::BitCast, Input, Ty,
326       Name, InsertBefore);
327   NewBC->setDebugLoc(DL);
328   return NewBC;
329 }
330 
331 /***********************************************************************
332  * createBitCastToElementType : create a bitcast to a vector with the
333  *    specified element type, combining bitcasts where applicable
334  */
335 static Value *createBitCastToElementType(Value *Input, Type *ElementTy,
336                                          const Twine &Name,
337                                          Instruction *InsertBefore,
338                                          const DataLayout *DL,
339                                          const DebugLoc &DbgLoc) {
340   unsigned ElBytes = ElementTy->getPrimitiveSizeInBits() / 8U;
341   if (!ElBytes) {
342     IGC_ASSERT(ElementTy->isPointerTy());
343     IGC_ASSERT(ElementTy->getPointerElementType()->isFunctionTy());
344     ElBytes = DL->getTypeSizeInBits(ElementTy) / 8;
345   }
346   unsigned InputBytes = Input->getType()->getPrimitiveSizeInBits() / 8U;
347   if (!InputBytes) {
348     Type *T = Input->getType();
349     if (T->isVectorTy())
350       T = cast<VectorType>(T)->getElementType();
351     IGC_ASSERT(T->isPointerTy());
352     IGC_ASSERT(T->getPointerElementType()->isFunctionTy());
353     InputBytes = DL->getTypeSizeInBits(T) / 8;
354   }
355   IGC_ASSERT_MESSAGE(!(InputBytes & (ElBytes - 1)), "non-integral number of elements");
356   auto Ty = IGCLLVM::FixedVectorType::get(ElementTy, InputBytes / ElBytes);
357   return createBitCast(Input, Ty, Name, InsertBefore, DbgLoc);
358 }
359 
360 /***********************************************************************
361  * combineBitCastWithUser : if PossibleBC is a bitcast, and it has a single
362  *    user that is also a bitcast, then combine them
363  *
364  * If combined, the two bitcast instructions are erased.
365  *
366  * This can happen because combining two rdregions with a bitcast between
367  * them can result in the bitcast being used by another bitcast that was
368  * already there.
369  */
370 static void combineBitCastWithUser(Value *PossibleBC)
371 {
372   if (auto BC1 = dyn_cast<BitCastInst>(PossibleBC)) {
373     if (BC1->hasOneUse()) {
374       if (auto BC2 = dyn_cast<BitCastInst>(BC1->use_begin()->getUser())) {
375         Value *CombinedBC = BC1->getOperand(0);
376         if (CombinedBC->getType() != BC2->getType())
377           CombinedBC = createBitCast(BC1->getOperand(0), BC2->getType(),
378               BC2->getName(), BC2, BC2->getDebugLoc());
379         BC2->replaceAllUsesWith(CombinedBC);
380         BC2->eraseFromParent();
381         BC1->eraseFromParent();
382       }
383     }
384   }
385 }
386 
387 /***********************************************************************
388  * processBitCast : process a bitcast whose input is rdregion
389  *
390  * We put the bitcast before the rdregion, in the hope that it will enable
391  * the rdregion to be baled in to something later on.
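 *
 * Illustrative example (operands abbreviated): for
 *   %rd = rdregion(<16 x i32> %in, ...)
 *   %bc = bitcast <8 x i32> %rd to <8 x float>
 * we instead bitcast %in to <16 x float> and read an equivalent region from
 * that, so the new rdregion feeds the users of %bc directly.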
392  */
393 void GenXRegionCollapsing::processBitCast(BitCastInst *BC)
394 {
395   if (BC->getType()->getScalarType()->isIntegerTy(1))
396     return;
397   auto Rd = dyn_cast<Instruction>(BC->getOperand(0));
398 
399   // Check whether to skip this optimization.
400   auto skip = [=] {
401     // Skip if this is not rdregion
402     if (!Rd || !GenXIntrinsic::isRdRegion(Rd))
403       return true;
404 
405     Value *OldValue = Rd->getOperand(GenXIntrinsic::GenXRegion::OldValueOperandNum);
406     if (GenXIntrinsic::isReadWritePredefReg(OldValue))
407       return true;
408 
409     // Single use, do optimization.
410     if (Rd->hasOneUse())
411       return false;
412 
413     // More than one use: check whether the rdregion is reading from a global.
414     // If so, still do the conversion, as the bitcast could be folded into the g_load.
415     while (auto CI = dyn_cast<BitCastInst>(OldValue))
416       OldValue = CI->getOperand(0);
417     auto LI = dyn_cast<LoadInst>(OldValue);
418     if (LI && getUnderlyingGlobalVariable(LI->getPointerOperand()))
419       return false;
420 
421     // Skip otherwise.
422     return true;
423   };
424 
425   if (skip())
426     return;
427 
428   // The skip lambda above checks, among other things, that Rd is a rdregion.
429   IGC_ASSERT(Rd);
430   IGC_ASSERT(GenXIntrinsic::isRdRegion(Rd));
431 
432   // We have a suitable rdregion as the input to the bitcast.
433   // Adjust the region parameters if possible so the element type is that of
434   // the result of the bitcast, instead of the input.
435   Region ROrig = makeRegionFromBaleInfo(Rd, BaleInfo());
436   Region R = makeRegionFromBaleInfo(Rd, BaleInfo());
437   auto ElTy = BC->getType()->getScalarType();
438   IGC_ASSERT(DL);
439   if (!R.changeElementType(ElTy, DL))
440     return;
441 
442   // We do not want this optimization applied if the resulting indirect
443   // region would have both a non-zero stride and a non-unit width, as that
444   // would require inefficient legalization.
445   bool OrigCorr = ((ROrig.Width == 1) || (ROrig.Stride == 0));
446   bool ChangedWrong = ((R.Width != 1) && (R.Stride != 0));
447   if (OrigCorr && ChangedWrong && R.Indirect)
448     return;
449 
450   // Create the new bitcast.
451   IGC_ASSERT(vc::getTypeSize(ElTy, DL).inBits());
452   auto Input = Rd->getOperand(GenXIntrinsic::GenXRegion::OldValueOperandNum);
453   auto NewBCTy = IGCLLVM::FixedVectorType::get(
454       ElTy, vc::getTypeSize(Input->getType(), DL).inBits() /
455                 vc::getTypeSize(ElTy, DL).inBits());
456   auto NewBC = CastInst::Create(Instruction::BitCast, Input, NewBCTy, "", Rd);
457   NewBC->takeName(BC);
458   NewBC->setDebugLoc(BC->getDebugLoc());
459   // Create the new rdregion.
460   auto NewRd = R.createRdRegion(NewBC, "", Rd, Rd->getDebugLoc(),
461       /*AllowScalar=*/!isa<VectorType>(BC->getType()));
462   NewRd->takeName(Rd);
463   // Replace uses.
464   BC->replaceAllUsesWith(NewRd);
465   // Caller removes BC.
466   Modified = true;
467 }
468 
469 /***********************************************************************
470  * processRdRegion : process a rdregion
471  *
472  * 1. If this rdregion is unused, it probably became so in the processing
473  *    of a later rdregion. Erase it.
474  *
475  * 2. Otherwise, see if the input to this rdregion is the result of
476  *    an earlier rdregion, and if so see if they can be combined. This can
477  *    work even if there are bitcasts and up to one sext/zext between the
478  *    two rdregions.
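 *
 *    A sketch of the sext/zext case (operands abbreviated):
 *
 *      %a = rdregion(%big, OuterR)
 *      %b = sext <8 x i16> %a to <8 x i32>
 *      %c = rdregion(%b, InnerR)
 *
 *    becomes, when combineRegions() allows it, a single rdregion of %big
 *    in the i16 element type, with the sext re-applied to its result.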
479  */
480 void GenXRegionCollapsing::processRdRegion(Instruction *InnerRd)
481 {
482   if (InnerRd->use_empty()) {
483     InnerRd->eraseFromParent();
484     Modified = true;
485     return;
486   }
487 
488   // We use genx::makeRegionWithOffset to get a Region object for a
489   // rdregion/wrregion throughout this pass, in order to ensure that, with an
490   // index that is V+const, we get the V and const separately
491   // (in Region::Indirect and Region::Offset).
492   // Then our index calculations can ensure that the constant add remains the
493   // last thing that happens in the calculation.
494   Region InnerR = genx::makeRegionWithOffset(InnerRd,
495                                              /*WantParentWidth=*/true);
496 
497   // Prevent region collapsing for a specific source replication pattern,
498   // in order to enable the swizzle optimization for Align16 instructions.
499   if (InnerRd->hasOneUse()) {
500     if (auto UseInst = dyn_cast<Instruction>(InnerRd->use_begin()->getUser())) {
501       if (UseInst->getOpcode() == Instruction::FMul) {
502         auto NextInst = dyn_cast<Instruction>(UseInst->use_begin()->getUser());
503         if (NextInst &&
504             (NextInst->getOpcode() == Instruction::FAdd ||
505              NextInst->getOpcode() == Instruction::FSub) &&
506           InnerR.ElementTy->getPrimitiveSizeInBits() == 64U &&
507           InnerR.Width == 2 &&
508           InnerR.Stride == 0 &&
509           InnerR.VStride == 2)
510           return;
511       }
512     }
513   }
514 
515   for (;;) {
516     Instruction *OuterRd = dyn_cast<Instruction>(InnerRd->getOperand(0));
517     // Go through any bitcasts and up to one sext/zext if necessary to find the
518     // outer rdregion.
519     Instruction *Extend = nullptr;
520     bool HadElementTypeChange = false;
521     for (;;) {
522       if (!OuterRd)
523         break; // input not result of earlier rdregion
524       if (GenXIntrinsic::isRdRegion(OuterRd)) {
525         Value *OldValue =
526             OuterRd->getOperand(GenXIntrinsic::GenXRegion::OldValueOperandNum);
527         // Do not optimize predefined register regions.
528         if (GenXIntrinsic::isReadWritePredefReg(OldValue))
529           OuterRd = nullptr;
530         break; // found the outer rdregion
531       }
532       if (isa<SExtInst>(OuterRd) || isa<ZExtInst>(OuterRd)) {
533         if (OuterRd->getOperand(0)->getType()->getScalarType()->isIntegerTy(1)) {
534           OuterRd = nullptr;
535           break; // input not result of earlier rdregion
536         }
537         if (Extend || HadElementTypeChange) {
538           OuterRd = nullptr;
539           break; // can only have one sext/zext between the rdregions, and
540                  // sext/zext not allowed if it is then subject to a bitcast
541                  // that changes the element type
542         }
543         // Remember the sext/zext instruction.
544         Extend = OuterRd;
545       } else if (isa<BitCastInst>(OuterRd)) {
546         if (OuterRd->getType()->getScalarType()
547             != OuterRd->getOperand(0)->getType()->getScalarType())
548           HadElementTypeChange = true;
549       } else {
550         OuterRd = nullptr;
551         break; // input not result of earlier rdregion
552       }
553       OuterRd = dyn_cast<Instruction>(OuterRd->getOperand(0));
554     }
555     if (!OuterRd)
556       break; // no outer rdregion that we can combine with
557     Region OuterR = genx::makeRegionWithOffset(OuterRd);
558     // There was a sext/zext. Because we are going to put that after the
559     // collapsed region, we want to modify the inner region to the
560     // extend's input element type without changing the region parameters
561     // (other than scaling the offset). We know that there is no element
562     // type changing bitcast between the extend and the inner rdregion.
563     if (Extend) {
564       if (InnerR.Indirect)
565         return; // cannot cope with indexed inner region and sext/zext
566       InnerR.ElementTy = Extend->getOperand(0)->getType()->getScalarType();
567       unsigned ExtInputElementBytes
568             = InnerR.ElementTy->getPrimitiveSizeInBits() / 8U;
569       InnerR.Offset = InnerR.Offset / InnerR.ElementBytes * ExtInputElementBytes;
570       InnerR.ElementBytes = ExtInputElementBytes;
571     }
572     // See if the regions can be combined. We call normalizeElementType with
573     // InnerR as the first arg so it prefers to normalize to that region's
574     // element type if possible. That can avoid a bitcast being put after the
575     // combined rdregion, which can help baling later on.
576     LLVM_DEBUG(dbgs() << "GenXRegionCollapsing::processRdRegion:\n"
577         "  OuterRd (line " << OuterRd->getDebugLoc().getLine() << "): " << *OuterRd << "\n"
578         "  InnerRd (line " << InnerRd->getDebugLoc().getLine() << "): " << *InnerRd << "\n");
579     if (!normalizeElementType(&InnerR, &OuterR, /*PreferFirst=*/true)) {
580       LLVM_DEBUG(dbgs() << "Cannot normalize element type\n");
581       return;
582     }
583 
584     // If it's a single-element extract from an indirect region,
585     // then check whether there exist some other extracts
586     if (OuterR.Indirect && (OuterR.NumElements != 1) &&
587         isSingleElementRdRExtract(InnerRd)) {
588       auto NumExtracts = llvm::count_if(OuterRd->uses(), [this](Use &U) {
589         return isSingleElementRdRExtract(cast<Instruction>(U.getUser()));
590       });
591       // If there are more extracts besides this one (InnerRd),
592       // then do not combine these regions, to prevent generation
593       // of extra address conversions for the combined region.
594       if (NumExtracts > 1)
595         return;
596     }
597 
598     Region CombinedR;
599     if (!combineRegions(&OuterR, &InnerR, &CombinedR))
600       return; // cannot combine
601 
602     // If the combined region is both indirect and splat, then do not combine.
603     // Otherwise, this leads to an infinite loop as later on we split such
604     // region reads.
605     auto isIndirectSplat = [](const Region &R) {
606       if (!R.Indirect)
607         return false;
608       if (R.Width != R.NumElements && !R.VStride &&
609           !isa<VectorType>(R.Indirect->getType()))
610         return true;
611       if (R.Width == 1 || R.Stride)
612         return false;
613       return true;
614     };
615     if (isIndirectSplat(CombinedR))
616       return;
617 
618     // Calculate index if necessary.
619     if (InnerR.Indirect) {
620       calculateIndex(&OuterR, &InnerR, &CombinedR,
621           InnerRd->getOperand(GenXIntrinsic::GenXRegion::RdIndexOperandNum),
622           InnerRd->getName() + ".indexcollapsed",
623           InnerRd, InnerRd->getDebugLoc());
624     }
625     // If the element type of the combined region does not match that of the
626     // outer region, we need to do a bitcast first.
627     Value *Input = OuterRd->getOperand(GenXIntrinsic::GenXRegion::OldValueOperandNum);
628     // InnerR.ElementTy not always equal to InnerRd->getType()->getScalarType() (look above)
629     if (InnerR.ElementTy != OuterRd->getType()->getScalarType())
630       Input = createBitCastToElementType(Input, InnerR.ElementTy,
631                                          Input->getName() +
632                                              ".bitcast_before_collapse",
633                                          OuterRd, DL, OuterRd->getDebugLoc());
634     // Create the combined rdregion.
635     Instruction *CombinedRd = CombinedR.createRdRegion(Input,
636         InnerRd->getName() + ".regioncollapsed", InnerRd, InnerRd->getDebugLoc(),
637         !isa<VectorType>(InnerRd->getType()));
638     // If we went through sext/zext, re-instate it here.
639     Value *NewVal = CombinedRd;
640     if (Extend) {
641       auto NewCI = CastInst::Create((Instruction::CastOps)Extend->getOpcode(),
642           NewVal, InnerRd->getType(), Extend->getName(), InnerRd);
643       NewCI->setDebugLoc(Extend->getDebugLoc());
644       NewVal = NewCI;
645     }
646     // If we still don't have the right type due to bitcasts in the original
647     // code, add a bitcast here.
648     NewVal = createBitCast(NewVal, InnerRd->getType(),
649         NewVal->getName() + ".bitcast_after_collapse", InnerRd,
650         InnerRd->getDebugLoc());
651     // Replace the inner read with the new value, and erase the inner read.
652     // Any other instructions between it and the outer read (inclusive) that
653     // become unused are erased later in this pass.
654     InnerRd->replaceAllUsesWith(NewVal);
655     InnerRd->eraseFromParent();
656     Modified = true;
657     // Check whether we just created a bitcast that can be combined with its
658     // user. If so, combine them.
659     combineBitCastWithUser(NewVal);
660     InnerRd = CombinedRd;
661     InnerR = genx::makeRegionWithOffset(InnerRd, /*WantParentWidth=*/true);
662     // Because the loop in runOnFunction does not re-process the new rdregion,
663     // loop back here to re-process it.
664   }
665   // InnerRd and InnerR are now the combined rdregion (or the original one if
666   // no combining was done).
667   // Check whether we have a rdregion that is both indirect and replicating,
668   // that we want to split.
669   splitReplicatingIndirectRdRegion(InnerRd, &InnerR);
670 }
671 
672 /***********************************************************************
673  * splitReplicatingIndirectRdRegion : if the rdregion is both indirect and
674  *    replicating, split out the indirect part so it is read only once
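 *
 * Illustrative example (parameters chosen for exposition only): an indirect
 * region with NumElements = 16, Width = 4, VStride = 0 replicates one 4-wide
 * row four times; it is split into an indirect rdregion that reads that row
 * once, followed by a direct rdregion that replicates it.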
675  */
676 void GenXRegionCollapsing::splitReplicatingIndirectRdRegion(
677     Instruction *Rd, Region *R)
678 {
679   if (!R->Indirect)
680     return;
681   if (R->Width != R->NumElements && !R->VStride
682       && !isa<VectorType>(R->Indirect->getType())) {
683     // Replicating rows. We want an indirect region that just reads
684     // one row
685     Region IndirR = *R;
686     IndirR.NumElements = IndirR.Width;
687     auto Indir = IndirR.createRdRegion(Rd->getOperand(0),
688         Rd->getName() + ".split_replicated_indir", Rd, Rd->getDebugLoc());
689     // ... and a direct region that replicates the row.
690     R->Indirect = nullptr;
691     R->Offset = 0;
692     R->Stride = 1;
693     auto NewRd = R->createRdRegion(Indir, "", Rd, Rd->getDebugLoc());
694     NewRd->takeName(Rd);
695     Rd->replaceAllUsesWith(NewRd);
696     Rd->eraseFromParent();
697     Modified = true;
698     return;
699   }
700   if (R->Width == 1 || R->Stride)
701     return;
702   // Replicating columns. We want an indirect region that just reads
703   // one column
704   Region IndirR = *R;
705   IndirR.NumElements = IndirR.NumElements / IndirR.Width;
706   IndirR.Width = 1;
707   auto Indir = IndirR.createRdRegion(Rd->getOperand(0),
708       Rd->getName() + ".split_replicated_indir", Rd, Rd->getDebugLoc());
709   // ... and a direct region that replicates the column.
710   R->Indirect = nullptr;
711   R->Offset = 0;
712   R->VStride = 1;
713   auto NewRd = R->createRdRegion(Indir, "", Rd, Rd->getDebugLoc());
714   NewRd->takeName(Rd);
715   Rd->replaceAllUsesWith(NewRd);
716   Rd->eraseFromParent();
717 }
718 
719 /***********************************************************************
720  * processWrRegionElim : process a wrregion and eliminate redundant writes
721  *
722  * This detects the following code:
723  *
724  *   B = wrregion(A, V1, R)
725  *   C = wrregion(B, V2, R)
726  *
727  * (where "R" is a region that is identical in the two wrregions);
728  * this can be collapsed to
729  *
730  *   D = wrregion(A, V2, R)
731  *
732  */
733 void GenXRegionCollapsing::processWrRegionElim(Instruction *OuterWr)
734 {
735   auto InnerWr = dyn_cast<Instruction>(
736       OuterWr->getOperand(GenXIntrinsic::GenXRegion::OldValueOperandNum));
737   if (!GenXIntrinsic::isWrRegion(InnerWr))
738     return;
739   // Only perform this optimisation if the only use is the outer wrregion;
740   // otherwise this seems to make the code spill more.
741   IGC_ASSERT(InnerWr);
742   if (!InnerWr->hasOneUse())
743     return;
744 
745   Region InnerR = genx::makeRegionFromBaleInfo(InnerWr, BaleInfo(),
746                                                /*WantParentWidth=*/true);
747   Region OuterR = genx::makeRegionFromBaleInfo(OuterWr, BaleInfo());
748   if (OuterR != InnerR)
749     return;
750   // Create the combined wrregion.
751   Instruction *CombinedWr = OuterR.createWrRegion(
752       InnerWr->getOperand(0),
753       OuterWr->getOperand(GenXIntrinsic::GenXRegion::NewValueOperandNum),
754       OuterWr->getName() + ".regioncollapsed", OuterWr, OuterWr->getDebugLoc());
755   OuterWr->replaceAllUsesWith(CombinedWr);
756   // Do not erase OuterWr here -- it gets erased by the caller.
757   Modified = true;
758 }
759 
760 /***********************************************************************
761  * processWrRegionBitCast : handle a wrregion whose "new value" is a
762  *      bitcast (before processing wrregion for region collapsing)
763  *
764  * Enter:   Inst = the wrregion
765  *
766  * Return:  replacement wrregion if any, else 0
767  *
768  * If the "new value" operand of the wrregion is a bitcast from scalar to
769  * 1-vector, or vice versa, then we can replace the wrregion with one that
770  * uses the input to the bitcast directly. This may enable later baling
771  * that would otherwise not happen.
772  *
773  * The bitcast typically arises from GenXLowering lowering an insertelement.
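 *
 * Illustrative example (operands abbreviated):
 *   %v = bitcast i32 %s to <1 x i32>
 *   %w = wrregion(%old, <1 x i32> %v, ...)
 * is replaced by a wrregion that writes %s directly:
 *   %w = wrregion(%old, i32 %s, ...)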
774  */
775 Instruction *GenXRegionCollapsing::processWrRegionBitCast(Instruction *WrRegion)
776 {
777   IGC_ASSERT(GenXIntrinsic::isWrRegion(WrRegion));
778   if (auto BC = dyn_cast<BitCastInst>(WrRegion->getOperand(
779           GenXIntrinsic::GenXRegion::NewValueOperandNum))) {
780     if (BC->getType()->getScalarType()
781         == BC->getOperand(0)->getType()->getScalarType()) {
782       // The bitcast is from scalar to 1-vector, or vice versa.
783       Region R = makeRegionFromBaleInfo(WrRegion, BaleInfo());
784       auto NewInst =
785           R.createWrRegion(WrRegion->getOperand(0), BC->getOperand(0), "",
786                            WrRegion, WrRegion->getDebugLoc());
787       NewInst->takeName(WrRegion);
788       WrRegion->replaceAllUsesWith(NewInst);
789       WrRegion->eraseFromParent();
790       return NewInst;
791     }
792   }
793   return nullptr;
794 }
795 
796 /***********************************************************************
797  * processWrRegionBitCast2 : handle a wrregion whose "new value" is a
798  *      bitcast (after processing wrregion for region collapsing)
799  *
800  * Enter:   WrRegion = the wrregion
801  *
802  * This does not erase WrRegion even if it becomes unused.
803  *
804  *
805  * If the "new value" operand of the wrregion is some other bitcast, then we
806  * change the wrregion to the pre-bitcast type and add new bitcasts for the
807  * "old value" input and the result. This makes it possible for the new value
808  * to be baled in to the wrregion.
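 *
 * A sketch of the transformation (operands abbreviated):
 *   %nv = bitcast <8 x i16> %x to <4 x i32>
 *   %w = wrregion(<16 x i32> %old, <4 x i32> %nv, R)
 * becomes
 *   %old.cast = bitcast <16 x i32> %old to <32 x i16>
 *   %w.cast = wrregion(<32 x i16> %old.cast, <8 x i16> %x, R in i16)
 *   %w = bitcast <32 x i16> %w.cast to <16 x i32>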
809  */
810 void GenXRegionCollapsing::processWrRegionBitCast2(Instruction *WrRegion)
811 {
812   auto BC = dyn_cast<BitCastInst>(WrRegion->getOperand(
813         GenXIntrinsic::GenXRegion::NewValueOperandNum));
814   if (!BC)
815     return;
816   Type *BCInputElementType = BC->getOperand(0)->getType()->getScalarType();
817   if (BCInputElementType->isIntegerTy(1))
818     return;
819 
820   Value *OldValue = WrRegion->getOperand(GenXIntrinsic::GenXRegion::OldValueOperandNum);
821   if (GenXIntrinsic::isReadWritePredefReg(OldValue))
822     return;
823 
824   // Get the region params for the replacement wrregion, and give up if the
825   // element type cannot be changed.
826   Region R = makeRegionFromBaleInfo(WrRegion, BaleInfo());
827   if (!R.changeElementType(BCInputElementType, DL))
828     return;
829   // Bitcast the "old value" input.
830   Value *OldVal = createBitCastToElementType(
831       OldValue,
832       BCInputElementType, WrRegion->getName() + ".precast", WrRegion, DL,
833       WrRegion->getDebugLoc());
834   // Create the replacement wrregion.
835   auto NewInst = R.createWrRegion(OldVal, BC->getOperand(0), "", WrRegion,
836                                   WrRegion->getDebugLoc());
837   NewInst->takeName(WrRegion);
838   // Cast it.
839   Value *Res = createBitCast(NewInst, WrRegion->getType(),
840       WrRegion->getName() + ".postcast", WrRegion, WrRegion->getDebugLoc());
841   WrRegion->replaceAllUsesWith(Res);
842 }
843 
844 // Check whether two values are bitwise identical.
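// Two values count as identical if they are the same Value (looking through
// a bitcast on either side), or, as a special case, if they are vloads of the
// same location with no store to that location in between.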
845 static bool isBitwiseIdentical(Value *V1, Value *V2, DominatorTree *DT) {
846   IGC_ASSERT_MESSAGE(V1, "null value");
847   IGC_ASSERT_MESSAGE(V2, "null value");
848   if (V1 == V2)
849     return true;
850   if (BitCastInst *BI = dyn_cast<BitCastInst>(V1))
851     V1 = BI->getOperand(0);
852   if (BitCastInst *BI = dyn_cast<BitCastInst>(V2))
853     V2 = BI->getOperand(0);
854 
855   // Special case arises from vload/vstore.
856   if (GenXIntrinsic::isVLoad(V1) && GenXIntrinsic::isVLoad(V2)) {
857     auto L1 = cast<CallInst>(V1);
858     auto L2 = cast<CallInst>(V2);
859 
860     // Loads from global variables.
861     auto GV1 = getUnderlyingGlobalVariable(L1->getOperand(0));
862     auto GV2 = getUnderlyingGlobalVariable(L2->getOperand(0));
863     Value *Addr = L1->getOperand(0);
864     if (GV1 && GV1 == GV2)
865       // OK.
866       Addr = GV1;
867     else if (L1->getOperand(0) != L2->getOperand(0))
868       // Check if loading from the same location.
869       return false;
870     else if (!isa<AllocaInst>(Addr))
871       // Check if this pointer is local and only used in vload/vstore.
872       return false;
873 
874     // Check if there is no store to the same location in between.
875     return !genx::hasMemoryDeps(L1, L2, Addr, DT);
876   }
877 
878   // Cannot prove.
879   return false;
880 }
881 
882 /***********************************************************************
883  * processWrRegion : process a wrregion
884  *
885  * Enter:   OuterWr = the wrregion instruction that we will attempt to use as
886  *                    the outer wrregion and collapse with inner ones
887  *
888  * Return:  the replacement wrregion if any, otherwise OuterWr
889  *
890  * This detects the following code:
891  *
892  *   B = rdregion(A, OuterR)
893  *   C = wrregion(B, V, InnerR)
894  *   D = wrregion(A, C, OuterR)
895  *
896  * (where "InnerR" and "OuterR" are the region parameters). This code can
897  * be collapsed to
898  *
899  *   D = wrregion(A, V, CombinedR)
900  *
901  * We want to do innermost wrregion combining first, but this pass visits
902  * instructions in the wrong order for that. So, when we see a wrregion
903  * here, we use recursion to scan back to find the innermost one and then work
904  * forwards to where we started.
905  */
906 Instruction *GenXRegionCollapsing::processWrRegion(Instruction *OuterWr)
907 {
908   IGC_ASSERT(OuterWr);
909   // Find the inner wrregion, skipping bitcasts.
910   auto InnerWr = dyn_cast<Instruction>(
911       OuterWr->getOperand(GenXIntrinsic::GenXRegion::NewValueOperandNum));
912   while (InnerWr && isa<BitCastInst>(InnerWr))
913     InnerWr = dyn_cast<Instruction>(InnerWr->getOperand(0));
914   if (!GenXIntrinsic::isWrRegion(InnerWr))
915     return OuterWr;
916   // Process inner wrregions first, recursively.
917   InnerWr = processWrRegion(InnerWr);
918   // Now process this one.
919   // Find the associated rdregion of the outer region, skipping bitcasts,
920   // and check it has the right region parameters.
921   IGC_ASSERT(InnerWr);
922   auto OuterRd = dyn_cast<Instruction>(InnerWr->getOperand(0));
923   while (OuterRd && isa<BitCastInst>(OuterRd))
924     OuterRd = dyn_cast<Instruction>(OuterRd->getOperand(0));
925   if (!GenXIntrinsic::isRdRegion(OuterRd))
926     return OuterWr;
927   IGC_ASSERT(OuterRd);
928   if (!isBitwiseIdentical(OuterRd->getOperand(0), OuterWr->getOperand(0), DT))
929     return OuterWr;
930   Region InnerR = genx::makeRegionWithOffset(InnerWr, /*WantParentWidth=*/true);
931   Region OuterR = genx::makeRegionWithOffset(OuterWr);
932   if (OuterR != genx::makeRegionWithOffset(OuterRd))
933     return OuterWr;
934   // See if the regions can be combined.
935   LLVM_DEBUG(dbgs() << "GenXRegionCollapsing::processWrRegion:\n"
936       "  OuterWr (line " << OuterWr->getDebugLoc().getLine() << "): " << *OuterWr << "\n"
937       "  InnerWr (line " << InnerWr->getDebugLoc().getLine() << "): " << *InnerWr << "\n");
938   if (!normalizeElementType(&OuterR, &InnerR)) {
939     LLVM_DEBUG(dbgs() << "Cannot normalize element type\n");
940     return OuterWr;
941   }
942   Region CombinedR;
943   if (!combineRegions(&OuterR, &InnerR, &CombinedR))
944     return OuterWr; // cannot combine
945   // Calculate index if necessary.
946   if (InnerR.Indirect) {
947     calculateIndex(&OuterR, &InnerR, &CombinedR,
948         InnerWr->getOperand(GenXIntrinsic::GenXRegion::WrIndexOperandNum),
949         InnerWr->getName() + ".indexcollapsed", OuterWr, InnerWr->getDebugLoc());
950   }
951   // Bitcast inputs if necessary.
952   Value *OldValInput = OuterRd->getOperand(GenXIntrinsic::GenXRegion::OldValueOperandNum);
953   OldValInput = createBitCastToElementType(OldValInput, InnerR.ElementTy,
954       OldValInput->getName() + ".bitcast_before_collapse", OuterWr, DL, OuterWr->getDebugLoc());
955   Value *NewValInput = InnerWr->getOperand(GenXIntrinsic::GenXRegion::NewValueOperandNum);
956   NewValInput = createBitCastToElementType(NewValInput, InnerR.ElementTy,
957       NewValInput->getName() + ".bitcast_before_collapse", OuterWr, DL, OuterWr->getDebugLoc());
958   // Create the combined wrregion.
959   Instruction *CombinedWr = CombinedR.createWrRegion(
960       OldValInput, NewValInput, InnerWr->getName() + ".regioncollapsed",
961       OuterWr, InnerWr->getDebugLoc());
962   // Bitcast to the original type if necessary.
963   Value *Res = createBitCast(CombinedWr, OuterWr->getType(),
964       CombinedWr->getName() + ".cast", OuterWr,
965       InnerWr->getDebugLoc());
966   // Replace all uses.
967   OuterWr->replaceAllUsesWith(Res);
968   // Do not erase OuterWr here, as (if this function recursed to process an
969   // inner wrregion first) OuterWr might be the same as Prev in the loop in
970   // runOnFunction(). For a recursive call of processWrRegion, it will
971   // eventually get visited and then erased as it has no uses.  For an outer
972   // call of processWrRegion, OuterWr is erased by the caller.
973   Modified = true;
974   return CombinedWr;
975 }
976 
977 /***********************************************************************
978  * processWrRegionSplat : process a wrregion
979  *
980  * Enter:   OuterWr = the wrregion instruction that we will attempt to use as
981  *                    the outer wrregion and collapse with inner ones
982  *
983  * Return:  the replacement wrregion if any, otherwise OuterWr
984  *
985  * This detects the following code:
986  *
987  *   C = wrregion(undef, V, InnerR)
988  *   D = wrregion(undef, C, OuterR)
989  *
990  * (where "InnerR" and "OuterR" are the region parameters). This code can
991  * be collapsed to
992  *
993  *   D = wrregion(undef, V, CombinedR)
994  *
995  * We want to do innermost wrregion combining first, but this pass visits
996  * instructions in the wrong order for that. So, when we see a wrregion
997  * here, we use recursion to scan back to find the innermost one and then work
998  * forwards to where we started.
999  */
1000 Instruction *GenXRegionCollapsing::processWrRegionSplat(Instruction *OuterWr)
1001 {
1002   IGC_ASSERT(OuterWr);
1003   // Find the inner wrregion, skipping bitcasts.
1004   auto InnerWr = dyn_cast<Instruction>(
1005       OuterWr->getOperand(GenXIntrinsic::GenXRegion::NewValueOperandNum));
1006   while (InnerWr && isa<BitCastInst>(InnerWr))
1007     InnerWr = dyn_cast<Instruction>(InnerWr->getOperand(0));
1008   if (!GenXIntrinsic::isWrRegion(InnerWr))
1009     return OuterWr;
1010   // Process inner wrregions first, recursively.
1011   InnerWr = processWrRegionSplat(InnerWr);
1012 
1013   // Now process this one.
1014   auto InnerSrc = dyn_cast<Constant>(InnerWr->getOperand(GenXIntrinsic::GenXRegion::OldValueOperandNum));
1015   if (!InnerSrc)
1016     return OuterWr;
1017   // Ensure that the combined region is well-defined.
1018   if (InnerSrc->getType()->getScalarSizeInBits() !=
1019       OuterWr->getType()->getScalarSizeInBits())
1020     return OuterWr;
1021 
1022   auto OuterSrc = dyn_cast<Constant>(OuterWr->getOperand(GenXIntrinsic::GenXRegion::OldValueOperandNum));
1023   if (!OuterSrc)
1024     return OuterWr;
1025   if (isa<UndefValue>(InnerSrc)) {
1026     // OK.
1027   } else {
1028     auto InnerSplat = InnerSrc->getSplatValue();
1029     auto OuterSplat = OuterSrc->getSplatValue();
1030     if (!InnerSplat || !OuterSplat || InnerSplat != OuterSplat)
1031       return OuterWr;
1032   }
1033 
1034   Region InnerR = genx::makeRegionWithOffset(InnerWr, /*WantParentWidth=*/true);
1035   Region OuterR = genx::makeRegionWithOffset(OuterWr);
1036   Region CombinedR;
1037   if (!combineRegions(&OuterR, &InnerR, &CombinedR))
1038     return OuterWr; // cannot combine
1039   // Calculate index if necessary.
1040   if (InnerR.Indirect) {
1041     calculateIndex(&OuterR, &InnerR, &CombinedR,
1042         InnerWr->getOperand(GenXIntrinsic::GenXRegion::WrIndexOperandNum),
1043         InnerWr->getName() + ".indexcollapsed", OuterWr, InnerWr->getDebugLoc());
1044   }
1045   // Bitcast inputs if necessary.
1046   Value *OldValInput = OuterSrc;
1047   Value *NewValInput = InnerWr->getOperand(1);
1048   NewValInput = createBitCastToElementType(NewValInput, OuterWr->getType()->getScalarType(),
1049       NewValInput->getName() + ".bitcast_before_collapse", OuterWr, DL, OuterWr->getDebugLoc());
1050   // Create the combined wrregion.
1051   Instruction *CombinedWr = CombinedR.createWrRegion(
1052       OldValInput, NewValInput, InnerWr->getName() + ".regioncollapsed",
1053       OuterWr, InnerWr->getDebugLoc());
1054   // Bitcast to the original type if necessary.
1055   Value *Res = createBitCast(CombinedWr, OuterWr->getType(),
1056       CombinedWr->getName() + ".cast", OuterWr,
1057       InnerWr->getDebugLoc());
1058   // Replace all uses.
1059   OuterWr->replaceAllUsesWith(Res);
1060   // Do not erase OuterWr here, as (if this function recursed to process an
1061   // inner wrregion first) OuterWr might be the same as Prev in the loop in
1062   // runOnFunction(). For a recursive call of processWrRegionSplat, it will
1063   // eventually get visited and then erased as it has no uses.  For an outer
1064   // call of processWrRegionSplat, OuterWr is erased by the caller.
1065   Modified = true;
1066   return CombinedWr;
1067 }
1068 
1069 /***********************************************************************
1070  * normalizeElementType : where two regions have different element size,
1071  *      make them the same if possible
1072  *
1073  * Enter:   R1 = first region
1074  *          R2 = second region
1075  *          PreferFirst = true to prefer the first region's element type
1076  *
1077  * Return:  false if failed
1078  *
1079  * If PreferFirst is false, this uses the larger element size if everything is
1080  * suitably aligned and the region with the smaller element size can be
1081  * converted to the larger element size.
1082  *
1083  * Otherwise, it uses the smaller element size if the region with the
1084  * larger element size can be converted to the smaller element size.
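 *
 * For example (illustrative): with one region using i32 elements and the
 * other using i16 elements, the i16 region can only be re-expressed in i32
 * if its offset and strides land on whole i32 elements, whereas the i32
 * region can usually be re-expressed in i16.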
1085  */
1086 bool GenXRegionCollapsing::normalizeElementType(Region *R1, Region *R2,
1087       bool PreferFirst)
1088 {
1089   if (R1->ElementBytes == R2->ElementBytes)
1090     return true; // nothing to do
1091   LLVM_DEBUG(dbgs() << "Before normalizeElementType:\n"
1092         "  R1: " << *R1 << "\n"
1093         "  R2: " << *R2 << "\n");
1094   // Set BigR to the region with the bigger element size, and SmallR to the
1095   // region with the smaller element size.
1096   bool PreferSmall = false;
1097   Region *BigR = nullptr, *SmallR = nullptr;
1098   if (R1->ElementBytes > R2->ElementBytes) {
1099     BigR = R1;
1100     SmallR = R2;
1101   } else {
1102     BigR = R2;
1103     SmallR = R1;
1104     PreferSmall = PreferFirst;
1105   }
1106   // Try the smaller element size first if it is preferred by the caller.
1107   if (PreferSmall)
1108     if (!BigR->Indirect) // big region not indirect
1109       if (BigR->changeElementType(SmallR->ElementTy, DL))
1110         return true;
1111   // Then try the bigger element size.
1112   if (!SmallR->Indirect) // small region not indirect
1113     if (SmallR->changeElementType(BigR->ElementTy, DL))
1114       return true;
1115   // Then try the smaller element size.
1116   if (!PreferSmall)
1117     if (!BigR->Indirect) // big region not indirect
1118       if (BigR->changeElementType(SmallR->ElementTy, DL))
1119         return true;
1120   return false;
1121 }
1122 
1123 /***********************************************************************
1124  * combineRegions : combine two regions if possible
1125  *
1126  * Enter:   OuterR = Region struct for outer region
1127  *          InnerR = Region struct for inner region
1128  *          CombinedR = Region struct to write combined region into
1129  *
1130  * Return:  true if combining is possible
1131  *
1132  * If combining is possible, this function sets up CombinedR. However,
1133  * CombinedR->Offset and CombinedR->Indirect are set assuming that the
1134  * inner region is direct.
1135  *
1136  * If OuterR->ElementTy != InnerR->ElementTy, this algo cannot determine
1137  * CombinedR->ElementTy, as the type depends on the order of respective
1138  * wr/rd regions (it should be the type of the last one).
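 *
 * Worked example (values chosen for exposition only), 4-byte elements:
 *   OuterR: Offset 0, Width 4, VStride 8, Stride 1 (a 4-wide slice of an
 *   8-wide parent); InnerR: 1D, Offset 20 (element 5), NumElements 2,
 *   Stride 1. Element 5 of the outer region lies at row 1, column 1, i.e.
 *   parent element 1*8 + 1 = 9, so CombinedR gets Offset 9*4 = 36 bytes,
 *   NumElements 2, Stride 1.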
1139  */
1140 bool GenXRegionCollapsing::combineRegions(const Region *OuterR,
1141     const Region *InnerR, Region *CombinedR)
1142 {
1143   LLVM_DEBUG(dbgs() << "GenXRegionCollapsing::combineRegions\n"
1144       "  OuterR: " << *OuterR << "\n"
1145       "  InnerR: " << *InnerR << "\n");
1146   if (InnerR->Indirect && isa<VectorType>(InnerR->Indirect->getType()))
1147     return false; // multi indirect not supported
1148   if (OuterR->Indirect && isa<VectorType>(OuterR->Indirect->getType()))
1149     return false; // multi indirect not supported
1150   if (OuterR->Mask)
1151     return false; // outer region predicated, cannot combine
1152   *CombinedR = *InnerR;
1153   CombinedR->Indirect = OuterR->Indirect;
1154   CombinedR->Stride *= OuterR->Stride;
1155   CombinedR->VStride *= OuterR->Stride;
1156   unsigned ElOffset = InnerR->Offset / InnerR->ElementBytes;
1157   if (OuterR->is2D()) {
1158     // Outer region is 2D: create the combined offset. For outer 2D
1159     // and inner indirect, what CombinedR->Offset is set to here is
1160     // ignored and overwritten by calculateIndex(), so it does not matter
1161     // that it is incorrect in that case.
1162     ElOffset = ElOffset / OuterR->Width * OuterR->VStride
1163         + ElOffset % OuterR->Width * OuterR->Stride;
1164   } else {
1165     // Outer region is 1D: create the combined offset. For the benefit
1166     // of inner indirect, where InnerR->Offset is just an offset from
1167     // InnerR->Indirect, we cope with InnerR->Offset being apparently
1168     // out of range (negative or too big).
1169     ElOffset *= OuterR->Stride;
1170   }
1171   CombinedR->Offset = OuterR->Offset + ElOffset * InnerR->ElementBytes;
1172   if (!OuterR->is2D()) {
1173     LLVM_DEBUG(dbgs() << "outer 1D: CombinedR: " << *CombinedR << "\n");
1174     return true; // outer region is 1D, can always combine
1175   }
1176   if (InnerR->isScalar()) {
1177     LLVM_DEBUG(dbgs() << "inner scalar/splat: CombinedR: " << *CombinedR << "\n");
1178     return true; // inner region is scalar/splat, can always combine
1179   }
1180   if (InnerR->Indirect) {
1181     // Indirect inner region. Can combine as long as inner vstride is a
1182     // multiple of the outer width, and the outer width in turn is a multiple
1183     // of the inner parent width.
1184     if (InnerR->ParentWidth && !(InnerR->VStride % (int)OuterR->Width)
1185         && !(OuterR->Width % InnerR->ParentWidth)) {
1186       CombinedR->VStride = InnerR->VStride / OuterR->Width * OuterR->VStride;
1187       LLVM_DEBUG(dbgs() << "inner indirect: CombinedR: " << *CombinedR << "\n");
1188       return true;
1189     }
1190     LLVM_DEBUG(dbgs() << "inner indirect: failed\n");
1191     return false;
1192   }
1193   // Inner region is not indirect.
1194   unsigned StartEl = InnerR->Offset / InnerR->ElementBytes;
1195   unsigned StartRow = StartEl / OuterR->Width;
1196   if (!InnerR->is2D()) {
1197     // Inner region is 1D but outer region is 2D.
1198     unsigned EndEl = StartEl + (InnerR->NumElements - 1) * InnerR->Stride;
1199     unsigned EndRow = EndEl / OuterR->Width;
1200     if (StartRow == EndRow) {
1201       // The whole 1D inner region fits in a row of the outer region.
1202       LLVM_DEBUG(dbgs() << "inner 1D outer 2D, fits in row: CombinedR: " << *CombinedR << "\n");
1203       return true;
1204     }
1205     if (EndRow == StartRow + 1 && !(InnerR->NumElements % 2)) {
1206       unsigned MidEl = StartEl + InnerR->NumElements / 2 * InnerR->Stride;
1207       if (InnerR->Stride > 0 && (unsigned)(MidEl - (EndRow * OuterR->Width))
1208             < (unsigned)InnerR->Stride) {
1209         // The 1D inner region is evenly split between two adjacent rows of
1210         // the outer region.
1211         CombinedR->VStride = (MidEl % OuterR->Width - StartEl % OuterR->Width)
1212             * OuterR->Stride + OuterR->VStride;
1213         CombinedR->Width = InnerR->NumElements / 2;
1214         LLVM_DEBUG(dbgs() << "inner 1D outer 2D, split between two rows: CombinedR: " << *CombinedR << "\n");
1215         return true;
1216       }
1217     }
1218     unsigned BeyondEndEl = EndEl + InnerR->Stride;
1219     if (BeyondEndEl % OuterR->Width == StartEl % OuterR->Width
1220         && !(OuterR->Width % InnerR->Stride)) {
1221       // The 1D inner region is evenly split between N adjacent rows of the
1222       // outer region, starting in the same column for each row.
1223       CombinedR->Width = OuterR->Width / InnerR->Stride;
1224       CombinedR->VStride = OuterR->VStride;
1225       LLVM_DEBUG(dbgs() << "inner 1D outer 2D, split between N rows: CombinedR: " << *CombinedR << "\n");
1226       return true;
1227     }
1228     LLVM_DEBUG(dbgs() << "inner 1D outer 2D, fail\n");
1229     return false; // All other 1D inner region cases fail.
1230   }
1231   if (!(InnerR->VStride % (int)OuterR->Width)) {
1232     // Inner vstride is a whole number of outer rows.
1233     CombinedR->VStride = OuterR->VStride * InnerR->VStride / (int)OuterR->Width;
1234     if (!InnerR->Indirect) {
1235       // For a direct inner region, calculate whether we can combine.
1236       unsigned StartEl = InnerR->Offset / InnerR->ElementBytes;
1237       unsigned StartRow = StartEl / OuterR->Width;
1238       unsigned EndRowOfFirstRow = (StartEl + (InnerR->Width - 1) * InnerR->Stride)
1239             / OuterR->Width;
1240       if (StartRow == EndRowOfFirstRow) {
1241         // Each row of inner region is within a row of outer region, starting
1242         // at the same column.
1243         LLVM_DEBUG(dbgs() << "row within row: CombinedR: " << *CombinedR << "\n");
1244         return true;
1245       }
1246     } else {
1247       // For an indirect inner region, use parent width to tell whether we can
1248       // combine.
1249       if (InnerR->ParentWidth && !(OuterR->Width % InnerR->ParentWidth)) {
1250         LLVM_DEBUG(dbgs() << "inner indirect, parentwidth ok: CombinedR: " << *CombinedR << "\n");
1251         return true;
1252       }
1253     }
1254   }
1255   // We could handle other cases like:
1256   //  - each row of inner region enclosed in a row of outer region
1257   //    but with a different column offset
1258   LLVM_DEBUG(dbgs() << "failed\n");
1259   return false;
1260 }
1261 
/***********************************************************************
 * calculateIndex : calculate the combined index in the case that the inner
 *      region is indirect
 *
 * Enter:   OuterR, InnerR = outer and inner regions
 *          CombinedR = combined region set up by combineRegions()
 *          InnerIndex = variable index for inner region, including the
 *              constant offset add that was extracted by the Region
 *              constructor into InnerR->Offset
 *          Name = name for new instruction(s)
 *          InsertBefore = insert before this instruction
 *          DL = debug loc for new instruction(s)
 *
 * This sets up CombinedR->Indirect and CombinedR->Offset.
 *
 * A Region has the offset set up as follows:
 *
 *  - For a direct region, R.Offset is the constant offset in bytes and
 *    R.Indirect is 0.
 *
 *  - Normally, for an indirect region, R.Offset is 0 and R.Indirect is the
 *    Value used for the offset (in bytes).
 *
 *  - But if the Value used for the offset is an add of a constant, then
 *    R.Offset is the constant offset and R.Indirect is the other operand of
 *    the add.
 *
 * In some code paths, this function needs the actual index of the inner region,
 * rather than the R.Offset and R.Indirect parts separated out by the Region
 * constructor. Thus it is passed InnerIndex, which is that actual index value.
 */
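// For example (hypothetical IR): if the inner region's index operand is
// "%idx.add = add i16 %idx, 96", the Region constructor stores R.Offset = 96
// and R.Indirect = %idx, while the InnerIndex passed here is the whole
// %idx.add value.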
void GenXRegionCollapsing::calculateIndex(const Region *OuterR,
    const Region *InnerR, Region *CombinedR, Value *InnerIndex,
    const Twine &Name, Instruction *InsertBefore, const DebugLoc &DL)
{
  if (!OuterR->is2D()) {
    // Outer region is 1D. We can leave CombinedR->Offset as
    // set by combineRegions, but we need to add the indices together, scaling
    // the inner one by the outer region's stride.
    Value *Idx = InnerR->Indirect;
    if (OuterR->Stride != 1) {
      Idx = insertOp(Instruction::Mul, Idx, OuterR->Stride, Name,
          InsertBefore, DL);
      LLVM_DEBUG(dbgs() << " calculateIndex: " << *Idx << "\n");
    }
    if (OuterR->Indirect) {
      Idx = insertOp(Instruction::Add, Idx, OuterR->Indirect, Name,
          InsertBefore, DL);
      LLVM_DEBUG(dbgs() << " calculateIndex: " << *Idx << "\n");
    }
    CombinedR->Indirect = Idx;
    LLVM_DEBUG(dbgs() << " calculateIndex result(1d): CombinedR: " << *CombinedR << "\n");
    return;
  }
  // Outer region is 2D. We need to split the inner region's index into row
  // and column of the outer region, then recombine. We are using InnerIndex,
  // which includes any constant offset add, so we need to adjust
  // CombinedR->Offset so it does not include InnerR->Offset.
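  // Worked example with hypothetical values: for an outer region of Width=8,
  // ElementBytes=4, Stride=1, VStride=16, an inner byte index of 40 gives
  // Col = 40 % 32 = 8 and Row = 40 / 32 = 1, which recombines below to
  // (1 * 16 * 4 / 1 + 8) * 1 = 72, i.e. element 18 of the outer region's
  // source vector.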
  CombinedR->Offset = OuterR->Offset;
  LLVM_DEBUG(dbgs() << " calculateIndex: Offset now " << CombinedR->Offset << "\n");
  Value *Col = insertOp(Instruction::URem, InnerIndex,
      OuterR->Width * OuterR->ElementBytes,
      Name, InsertBefore, DL);
  LLVM_DEBUG(dbgs() << " calculateIndex: " << *Col << "\n");
  Value *Row = insertOp(Instruction::UDiv, InnerIndex,
      OuterR->Width * OuterR->ElementBytes,
      Name, InsertBefore, DL);
  LLVM_DEBUG(dbgs() << " calculateIndex: " << *Row << "\n");
  Value *Idx = nullptr;
  if (!(OuterR->VStride % OuterR->Stride)) {
    // We need to multiply Row by VStride and Col by Stride. However, Stride
    // divides VStride evenly, so we can factor the multiply by Stride out of
    // both terms.
    Idx = insertOp(Instruction::Mul, Row,
        OuterR->VStride * OuterR->ElementBytes / OuterR->Stride,
        Name, InsertBefore, DL);
    LLVM_DEBUG(dbgs() << " calculateIndex: " << *Idx << "\n");
    Idx = insertOp(Instruction::Add, Idx, Col, Name, InsertBefore, DL);
    LLVM_DEBUG(dbgs() << " calculateIndex: " << *Idx << "\n");
    Idx = insertOp(Instruction::Mul, Idx, OuterR->Stride, Name, InsertBefore, DL);
    LLVM_DEBUG(dbgs() << " calculateIndex: " << *Idx << "\n");
  } else {
    // Need to do Row*VStride and Col*Stride separately.
    Idx = insertOp(Instruction::Mul, Row,
        OuterR->VStride * OuterR->ElementBytes, Name, InsertBefore, DL);
    LLVM_DEBUG(dbgs() << " calculateIndex: " << *Idx << "\n");
    Col = insertOp(Instruction::Mul, Col, OuterR->Stride, Name, InsertBefore, DL);
    LLVM_DEBUG(dbgs() << " calculateIndex: " << *Col << "\n");
    Idx = insertOp(Instruction::Add, Idx, Col, Name, InsertBefore, DL);
    LLVM_DEBUG(dbgs() << " calculateIndex: " << *Idx << "\n");
  }
  if (OuterR->Indirect) {
    Idx = insertOp(Instruction::Add, Idx, OuterR->Indirect,
        Name, InsertBefore, DL);
    LLVM_DEBUG(dbgs() << " calculateIndex: " << *Idx << "\n");
  }
  CombinedR->Indirect = Idx;
  LLVM_DEBUG(dbgs() << " calculateIndex result(2d): CombinedR: " << *CombinedR << "\n");
}

/***********************************************************************
 * insertOp : insert a binary op, strength-reducing a multiply, divide or
 *    remainder by a power-of-two constant to a shift or mask where possible
 *
 * The first overload takes an unsigned constant right-hand operand and wraps
 * it in an i16 constant before calling the second overload.
 */
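// Illustration (hypothetical operands): insertOp(Instruction::Mul, %lhs, 8, ...)
// emits "shl i16 %lhs, 3", insertOp(Instruction::URem, %lhs, 8, ...) emits
// "and i16 %lhs, 7", and a multiply or divide by 1 simply returns %lhs;
// constants that are not powers of two fall through to the plain binary
// operator.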
Value *GenXRegionCollapsing::insertOp(Instruction::BinaryOps Opcode, Value *Lhs,
    unsigned Rhs, const Twine &Name, Instruction *InsertBefore, const DebugLoc &DL)
{
  auto I16Ty = Type::getInt16Ty(InsertBefore->getContext());
  return insertOp(Opcode, Lhs,
      Constant::getIntegerValue(I16Ty, APInt(16, Rhs)),
      Name, InsertBefore, DL);
}

Value *GenXRegionCollapsing::insertOp(Instruction::BinaryOps Opcode, Value *Lhs,
    Value *Rhs, const Twine &Name, Instruction *InsertBefore, const DebugLoc &DL)
{
  if (auto C = dyn_cast<ConstantInt>(Rhs)) {
    int RhsVal = C->getZExtValue();
    int LogVal = genx::exactLog2(RhsVal);
    if (LogVal >= 0) {
      switch (Opcode) {
        case Instruction::Mul:
          // multiply by power of 2 -> shl
          if (!LogVal)
            return Lhs;
          Rhs = Constant::getIntegerValue(C->getType(), APInt(16, LogVal));
          Opcode = Instruction::Shl;
          break;
        case Instruction::UDiv:
          // divide by power of 2 -> lshr
          if (!LogVal)
            return Lhs;
          Rhs = Constant::getIntegerValue(C->getType(), APInt(16, LogVal));
          Opcode = Instruction::LShr;
          break;
        case Instruction::URem:
          // remainder by power of 2 -> and
          Rhs = Constant::getIntegerValue(C->getType(), APInt(16, RhsVal - 1));
          Opcode = Instruction::And;
          break;
        default:
          break;
      }
    }
  }
  auto Inst = BinaryOperator::Create(Opcode, Lhs, Rhs, Name, InsertBefore);
  Inst->setDebugLoc(DL);
  return Inst;
}

bool GenXRegionCollapsing::isSingleElementRdRExtract(Instruction *I) {
  if (!GenXIntrinsic::isRdRegion(I))
    return false;
  Region R = genx::makeRegionWithOffset(I, /*WantParentWidth=*/true);
  return R.NumElements == 1 && !R.Indirect;
}
