1 /*========================== begin_copyright_notice ============================
2 
3 Copyright (C) 2018-2021 Intel Corporation
4 
5 SPDX-License-Identifier: MIT
6 
7 ============================= end_copyright_notice ===========================*/
8 
9 //
10 // Implementation of methods for CMRegion class
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "llvmWrapper/IR/DerivedTypes.h"
15 #include "llvmWrapper/Support/TypeSize.h"
16 
17 #include "vc/Utils/GenX/Region.h"
18 #include "vc/Utils/GenX/TypeSize.h"
19 
20 #include "llvm/ADT/SmallBitVector.h"
21 #include "llvm/Analysis/ConstantFolding.h"
22 #include "llvm/Analysis/TargetLibraryInfo.h"
23 #include "llvm/GenXIntrinsics/GenXIntrinsics.h"
24 #include "llvm/IR/Constants.h"
25 #include "llvm/IR/DerivedTypes.h"
26 #include "llvm/IR/Function.h"
27 #include "llvm/IR/Instructions.h"
28 #include "llvm/IR/IntrinsicInst.h"
29 #include "llvm/IR/Intrinsics.h"
30 #include "llvm/Support/Debug.h"
31 #include "llvm/Support/MathExtras.h"
32 #include "Probe/Assertion.h"
33 
34 using namespace llvm;
35 
36 // Find the datalayout if possible.
GetDL(Value * V)37 const DataLayout *GetDL(Value *V) {
38   if (auto Inst = dyn_cast_or_null<Instruction>(V))
39     return &Inst->getParent()->getParent()->getParent()->getDataLayout();
40   if (auto Arg = dyn_cast_or_null<Argument>(V))
41       return &Arg->getParent()->getParent()->getDataLayout();
42   return nullptr;
43 }
44 
45 /***********************************************************************
46  * Region constructor from a type
47  */
CMRegion(Type * Ty,const DataLayout * DL)48 CMRegion::CMRegion(Type *Ty, const DataLayout *DL)
49     : ElementBytes(0), ElementTy(0), NumElements(1), VStride(0), Width(1),
50       Stride(1), Offset(0), Indirect(0), IndirectIdx(0), IndirectAddrOffset(0),
51       Mask(0), ParentWidth(0)
52 {
53   IGC_ASSERT_MESSAGE(!Ty->isAggregateType(),
54     "cannot create region based on an aggregate type");
55   ElementTy = Ty;
56   if (IGCLLVM::FixedVectorType *VT =
57           dyn_cast<IGCLLVM::FixedVectorType>(ElementTy)) {
58     ElementTy = VT->getElementType();
59     NumElements = VT->getNumElements();
60     Width = NumElements;
61   }
62   if (DL) {
63     unsigned BitSize = DL->getTypeSizeInBits(ElementTy);
64     ElementBytes = alignTo<8>(BitSize) / 8;
65   } else {
66     unsigned BitSize = ElementTy->getPrimitiveSizeInBits();
67     ElementBytes = alignTo<8>(BitSize) / 8;
68     IGC_ASSERT_MESSAGE(ElementBytes,
69       "Cannot compute element size without data layout");
70   }
71 }
72 
73 /***********************************************************************
74  * Region constructor from a value
75  */
CMRegion(Value * V,const DataLayout * DL)76 CMRegion::CMRegion(Value *V, const DataLayout *DL)
77     : CMRegion(V->getType(), DL ? DL : GetDL(V)) {}
78 
79 /***********************************************************************
80  * Region constructor from a rd/wr region and its BaleInfo
81  * This also works with rdpredregion and wrpredregion, with Offset in
82  * bits rather than bytes, and with ElementBytes set to 1.
83  */
CMRegion(Instruction * Inst,bool WantParentWidth)84 CMRegion::CMRegion(Instruction *Inst, bool WantParentWidth)
85     : ElementBytes(0), ElementTy(0), NumElements(1), VStride(1), Width(1),
86       Stride(1), Offset(0), Indirect(0), IndirectIdx(0), IndirectAddrOffset(0),
87       Mask(0), ParentWidth(0)
88 {
89   // Determine where to get the subregion value from and which arg index
90   // the region parameters start at.
91   unsigned ArgIdx = 0;
92   Value *Subregion = 0;
93   IGC_ASSERT(isa<CallInst>(Inst));
94   switch (GenXIntrinsic::getGenXIntrinsicID(Inst)) {
95     case GenXIntrinsic::genx_rdpredregion:
96       NumElements =
97           cast<IGCLLVM::FixedVectorType>(Inst->getType())->getNumElements();
98       Width = NumElements;
99       Offset = cast<ConstantInt>(Inst->getOperand(1))->getZExtValue();
100       ElementBytes = 1;
101       return;
102     case GenXIntrinsic::genx_wrpredregion:
103       NumElements =
104           cast<IGCLLVM::FixedVectorType>(Inst->getOperand(1)->getType())
105               ->getNumElements();
106       Width = NumElements;
107       Offset = cast<ConstantInt>(Inst->getOperand(2))->getZExtValue();
108       ElementBytes = 1;
109       return;
110     case GenXIntrinsic::genx_rdregioni:
111     case GenXIntrinsic::genx_rdregionf:
112       ArgIdx = 1;
113       // The size/type of the region is given by the return value:
114       Subregion = Inst;
115       break;
116     case GenXIntrinsic::genx_wrregioni:
117     case GenXIntrinsic::genx_wrregionf:
118     case GenXIntrinsic::genx_wrconstregion:
119       ArgIdx = 2;
120       // The size/type of the region is given by the "subregion value to
121       // write" operand:
122       Subregion = Inst->getOperand(1);
123       // For wrregion, while we're here, also get the mask. We set mask to NULL
124       // if the mask operand is constant 1 (i.e. not predicated).
125       Mask = Inst->getOperand(GenXIntrinsic::GenXRegion::PredicateOperandNum);
126       if (auto C = dyn_cast<Constant>(Mask))
127         if (C->isAllOnesValue())
128           Mask = 0;
129       break;
130     default:
131       IGC_ASSERT(0);
132       break;
133   }
134   // Get the region parameters.
135   IGC_ASSERT(Subregion);
136   ElementTy = Subregion->getType();
137   if (IGCLLVM::FixedVectorType *VT =
138           dyn_cast<IGCLLVM::FixedVectorType>(ElementTy)) {
139     ElementTy = VT->getElementType();
140     NumElements = VT->getNumElements();
141   }
142   ElementBytes = ElementTy->getPrimitiveSizeInBits() / 8;
143   if (ElementTy->getPrimitiveSizeInBits())
144     ElementBytes = ElementBytes ? ElementBytes : 1;
145   VStride = cast<ConstantInt>(Inst->getOperand(ArgIdx))->getSExtValue();
146   Width = cast<ConstantInt>(Inst->getOperand(ArgIdx + 1))->getSExtValue();
147   Stride = cast<ConstantInt>(Inst->getOperand(ArgIdx + 2))->getSExtValue();
148   ArgIdx += 3;
149   // Get the start index.
150   Value *V = Inst->getOperand(ArgIdx);
151   IGC_ASSERT_MESSAGE(V->getType()->getScalarType()->isIntegerTy(16),
152     "region index must be i16 or vXi16 type");
153 
154   if (ConstantInt *CI = dyn_cast<ConstantInt>(V))
155     Offset = CI->getSExtValue(); // Constant index.
156   else {
157     Indirect = V; // Index is variable; assume no baled in add.
158     // For a variable index, get the parent width arg.
159     ConstantInt *PW = dyn_cast<ConstantInt>(Inst->getOperand(ArgIdx + 1));
160     if (PW)
161       ParentWidth = PW->getZExtValue();
162   }
163   // We do some trivial legalization here. The legalization pass does not
164   // make these changes; instead we do them here so they are not permanently
165   // written back into the IR but are made on the fly each time some other
166   // pass uses this code to get the region info.
167   if (NumElements == 1) {
168     Width = Stride = 1;
169     VStride = 0;
170   } else {
171     if (NumElements <= Width) {
172       Width = NumElements;
173       VStride = 0;
174     } else if ((unsigned)VStride == Width * Stride) {
175       // VStride == Width * Stride, so we can canonicalize to a 1D region,
176       // but only if not indirect or not asked to preserve parentwidth,
177       // and never if multi-indirect.
178       if (!Indirect
179           || (!isa<VectorType>(Indirect->getType()) && !WantParentWidth)) {
180         Width = NumElements;
181         VStride = 0;
182         ParentWidth = 0;
183       }
184     } else if (Width == 1) {
185       // We can turn a 2D width 1 region into a 1D region, but if it is
186       // indirect it invalidates ParentWidth. So only do it if not asked
187       // to keep ParentWidth. Also we cannot do it if it is multi-indirect.
188       if (!Indirect
189           || (!isa<VectorType>(Indirect->getType()) && !WantParentWidth)) {
190         Width = NumElements;
191         Stride = VStride;
192         VStride = 0;
193         ParentWidth = 0;
194       }
195     }
196     if (Stride == 0 && Width == NumElements) {
197       // Canonical scalar region.
198       Width = 1;
199       VStride = 0;
200     }
201   }
202 }
203 
204 /***********************************************************************
205  * Region constructor from bitmap of which elements to set
206  *
207  * Enter:   Bits = bitmap of which elements to set
208  *          ElementBytes = bytes per element
209  *
210  * It is assumed that Bits represents a legal 1D region.
211  */
CMRegion(unsigned Bits,unsigned ElementBytes)212 CMRegion::CMRegion(unsigned Bits, unsigned ElementBytes)
213     : ElementBytes(ElementBytes), ElementTy(0), NumElements(1), VStride(1),
214       Width(1), Stride(1), Offset(0), Indirect(0), IndirectIdx(0),
215       IndirectAddrOffset(0), Mask(0), ParentWidth(0)
216 {
217   IGC_ASSERT(Bits);
218   Offset = countTrailingZeros(Bits, ZB_Undefined);
219   Bits >>= Offset;
220   Offset *= ElementBytes;
221   if (Bits != 1) {
222     Stride = countTrailingZeros(Bits & ~1, ZB_Undefined);
223     NumElements = Width = countPopulation(Bits);
224   }
225 }
226 
227 /***********************************************************************
228  * CMRegion::getSubregion : modify Region struct for a subregion
229  *
230  * Enter:   StartIdx = start index of subregion (in elements)
231  *          Size = size of subregion (in elements)
232  *
233  * This does not modify the Mask; the caller needs to do that separately.
234  */
getSubregion(unsigned StartIdx,unsigned Size)235 void CMRegion::getSubregion(unsigned StartIdx, unsigned Size)
236 {
237   if (Indirect && isa<VectorType>(Indirect->getType())) {
238     // Vector indirect (multi indirect). Set IndirectIdx to the index of
239     // the start element in the vector indirect.
240     IndirectIdx = StartIdx / Width;
241     StartIdx %= Width;
242   }
243   int AddOffset = StartIdx / Width * VStride;
244   AddOffset += StartIdx % Width * Stride;
245   AddOffset *= ElementBytes;
246   Offset += AddOffset;
247   if (!(StartIdx % Width) && !(Size % Width)) {
248     // StartIdx is at the start of a row and Size is a whole number of
249     // rows.
250   } else if (StartIdx % Width + Size > Width) {
251     // The subregion goes over a row boundary. This can only happen if there
252     // is only one row split and it is exactly in the middle.
253     VStride += (Size / 2 - Width) * Stride;
254     Width = Size / 2;
255   } else {
256     // Within a single row.
257     Width = Size;
258     VStride = Size * Stride;
259   }
260   NumElements = Size;
261 }
262 
263 /***********************************************************************
264  * CMRegion::createRdRegion : create rdregion intrinsic from "this" Region
265  *
266  * Enter:   Input = vector value to extract subregion from
267  *          Name = name for new instruction
268  *          InsertBefore = insert new inst before this point
269  *          DL = DebugLoc to give the new instruction
270  *          AllowScalar = true to return scalar if region is size 1
271  *
272  * Return:  newly created instruction
273  */
createRdRegion(Value * Input,const Twine & Name,Instruction * InsertBefore,const DebugLoc & DL,bool AllowScalar)274 Instruction *CMRegion::createRdRegion(Value *Input, const Twine &Name,
275     Instruction *InsertBefore, const DebugLoc &DL, bool AllowScalar)
276 {
277   IGC_ASSERT_MESSAGE(ElementBytes, "not expecting i1 element type");
278 
279   Value *StartIdx = getStartIdx(Name, InsertBefore, DL);
280   IntegerType *I32Ty = Type::getInt32Ty(Input->getContext());
281   Value *ParentWidthArg = UndefValue::get(I32Ty);
282   if (Indirect)
283     ParentWidthArg = ConstantInt::get(I32Ty, ParentWidth);
284   Value *Args[] = {   // Args to new rdregion:
285       Input, // input to original rdregion
286       ConstantInt::get(I32Ty, VStride), // vstride
287       ConstantInt::get(I32Ty, Width), // width
288       ConstantInt::get(I32Ty, Stride), // stride
289       StartIdx, // start index (in bytes)
290       ParentWidthArg // parent width (if variable start index)
291   };
292   Type *ElTy =
293       cast<IGCLLVM::FixedVectorType>(Args[0]->getType())->getElementType();
294   Type *RegionTy;
295   if (NumElements != 1 || !AllowScalar)
296     RegionTy = IGCLLVM::FixedVectorType::get(ElTy, NumElements);
297   else
298     RegionTy = ElTy;
299   Module *M = InsertBefore->getParent()->getParent()->getParent();
300   auto IID = ElTy->isFloatingPointTy()
301       ? GenXIntrinsic::genx_rdregionf : GenXIntrinsic::genx_rdregioni;
302   Function *Decl = getGenXRegionDeclaration(M, IID, RegionTy, Args);
303   Instruction *NewInst = CallInst::Create(Decl, Args, Name, InsertBefore);
304   NewInst->setDebugLoc(DL);
305   return NewInst;
306 }
307 
308 /***********************************************************************
309  * CMRegion::createWrRegion : create wrregion instruction for subregion
310  * CMRegion::createWrConstRegion : create wrconstregion instruction for subregion
311  *
312  * Enter:   OldVal = vector value to insert subregion into (can be undef)
313  *          Input = subregion value to insert (can be scalar, as long as
314  *                  region size is 1)
315  *          Name = name for new instruction
316  *          InsertBefore = insert new inst before this point
317  *          DL = DebugLoc to give any new instruction
318  *
319  * Return:  The new wrregion instruction. However, if it would have had a
320  *          predication mask of all 0s, it is omitted and OldVal is returned
321  *          instead.
322  */
createWrRegion(Value * OldVal,Value * Input,const Twine & Name,Instruction * InsertBefore,const DebugLoc & DL)323 Instruction *CMRegion::createWrRegion(Value *OldVal, Value *Input,
324                                       const Twine &Name,
325                                       Instruction *InsertBefore,
326                                       const DebugLoc &DL) {
327   return createWrCommonRegion(OldVal->getType()->isFPOrFPVectorTy()
328         ? GenXIntrinsic::genx_wrregionf : GenXIntrinsic::genx_wrregioni,
329       OldVal, Input,
330       Name, InsertBefore, DL);
331 }
332 
createWrConstRegion(Value * OldVal,Value * Input,const Twine & Name,Instruction * InsertBefore,const DebugLoc & DL)333 Instruction *CMRegion::createWrConstRegion(Value *OldVal, Value *Input,
334                                            const Twine &Name,
335                                            Instruction *InsertBefore,
336                                            const DebugLoc &DL) {
337   IGC_ASSERT(!Indirect);
338   IGC_ASSERT(!Mask);
339   IGC_ASSERT(isa<Constant>(Input));
340   return createWrCommonRegion(GenXIntrinsic::genx_wrconstregion, OldVal, Input,
341       Name, InsertBefore, DL);
342 }
343 
createWrCommonRegion(GenXIntrinsic::ID IID,Value * OldVal,Value * Input,const Twine & Name,Instruction * InsertBefore,const DebugLoc & DL)344 Instruction *CMRegion::createWrCommonRegion(GenXIntrinsic::ID IID,
345                                             Value *OldVal, Value *Input,
346                                             const Twine &Name,
347                                             Instruction *InsertBefore,
348                                             const DebugLoc &DL) {
349   IGC_ASSERT_MESSAGE(ElementBytes, "not expecting i1 element type");
350   if (isa<VectorType>(Input->getType()))
351     IGC_ASSERT_MESSAGE(
352         NumElements ==
353             cast<IGCLLVM::FixedVectorType>(Input->getType())->getNumElements(),
354         "input value and region are inconsistent");
355   else
356     IGC_ASSERT_MESSAGE(NumElements == 1, "input value and region are inconsistent");
357   IGC_ASSERT_MESSAGE(OldVal->getType()->getScalarType() == Input->getType()->getScalarType(),
358     "scalar type mismatch");
359   Value *StartIdx = getStartIdx(Name, InsertBefore, DL);
360   IntegerType *I32Ty = Type::getInt32Ty(Input->getContext());
361   Value *ParentWidthArg = UndefValue::get(I32Ty);
362   if (Indirect)
363     ParentWidthArg = ConstantInt::get(I32Ty, ParentWidth);
364   // Get the mask value. If R.Mask is 0, then the wrregion is unpredicated
365   // and we just use constant 1.
366   Value *MaskArg = Mask;
367   if (!MaskArg)
368     MaskArg = ConstantInt::get(Type::getInt1Ty(Input->getContext()), 1);
369   // Build the wrregion.
370   Value *Args[] = {   // Args to new wrregion:
371       OldVal, // original vector
372       Input, // value to write into subregion
373       ConstantInt::get(I32Ty, VStride), // vstride
374       ConstantInt::get(I32Ty, Width), // width
375       ConstantInt::get(I32Ty, Stride), // stride
376       StartIdx, // start index (in bytes)
377       ParentWidthArg, // parent width (if variable start index)
378       MaskArg // mask
379   };
380   Module *M = InsertBefore->getParent()->getParent()->getParent();
381   Function *Decl = getGenXRegionDeclaration(M, IID, nullptr, Args);
382   Instruction *NewInst = CallInst::Create(Decl, Args, Name, InsertBefore);
383   NewInst->setDebugLoc(DL);
384   return NewInst;
385 }
386 
387 /***********************************************************************
388  * CMRegion::createRdPredRegion : create rdpredregion instruction
389  * CMRegion::createRdPredRegionOrConst : create rdpredregion instruction, or
390  *      simplify to constant
391  *
392  * Enter:   Input = vector value to extract subregion from
393  *          Index = start index of subregion
394  *          Size = size of subregion
395  *          Name = name for new instruction
396  *          InsertBefore = insert new inst before this point
397  *          DL = DebugLoc to give any new instruction
398  *
399  * Return:  The new rdpredregion instruction
400  *
401  * Unlike createRdRegion, this is a static method in Region, because you pass
402  * the region parameters (the start index and size) directly into this method.
403  */
createRdPredRegion(Value * Input,unsigned Index,unsigned Size,const Twine & Name,Instruction * InsertBefore,const DebugLoc & DL)404 Instruction *CMRegion::createRdPredRegion(Value *Input, unsigned Index,
405     unsigned Size, const Twine &Name, Instruction *InsertBefore,
406     const DebugLoc &DL)
407 {
408   Type *I32Ty = Type::getInt32Ty(InsertBefore->getContext());
409   Value *Args[] = { // Args to new rdpredregion call:
410     Input, // input predicate
411     ConstantInt::get(I32Ty, Index) // start offset
412   };
413   auto RetTy =
414       IGCLLVM::FixedVectorType::get(Args[0]->getType()->getScalarType(), Size);
415   Module *M = InsertBefore->getParent()->getParent()->getParent();
416   Function *Decl = getGenXRegionDeclaration(M, GenXIntrinsic::genx_rdpredregion,
417       RetTy, Args);
418   Instruction *NewInst = CallInst::Create(Decl, Args, Name, InsertBefore);
419   NewInst->setDebugLoc(DL);
420   return NewInst;
421 }
422 
423 /***********************************************************************
424  * CMRegion::createRdVectorSplat: create vector splat using rdregion
425  *
426  * Enter:   NumElements = number of elements in the resulting vector
427  *          Input = Value from which splat is built
428  *          Name = name for new instruction
429  *          InsertBefore = insert new inst before this point
430  *          DL = DebugLoc to give any new instruction
431  *
432  * Return:  The new rdregion representing vector splat
433  */
createRdVectorSplat(const DataLayout & DL,unsigned NumElements,Value * SplattedValue,const Twine & Name,Instruction * InsertPt,const DebugLoc & DbgLoc)434 Value *CMRegion::createRdVectorSplat(const DataLayout &DL, unsigned NumElements,
435                                      Value *SplattedValue, const Twine &Name,
436                                      Instruction *InsertPt,
437                                      const DebugLoc &DbgLoc) {
438   auto *V1Cast = CastInst::Create(
439       Instruction::BitCast, SplattedValue,
440       IGCLLVM::FixedVectorType::get(SplattedValue->getType(), 1),
441       SplattedValue->getName() + ".v1cast", InsertPt);
442   V1Cast->setDebugLoc(DbgLoc);
443 
444   CMRegion R(V1Cast->getType(), &DL);
445   R.NumElements = NumElements;
446   R.Width = 1;
447   R.VStride = 0;
448   R.Stride = 0;
449   return R.createRdRegion(V1Cast, SplattedValue->getName() + ".splat", InsertPt,
450                           DbgLoc, false /* AllowScalar */);
451 }
452 
453 /***********************************************************************
454 * GetConstantSubvector : get a contiguous region from a vector constant
455 */
GetConstantSubvector(Constant * V,unsigned StartIdx,unsigned Size)456 static Constant *GetConstantSubvector(Constant *V,
457   unsigned StartIdx, unsigned Size)
458 {
459   Type *ElTy = cast<IGCLLVM::FixedVectorType>(V->getType())->getElementType();
460   Type *RegionTy = IGCLLVM::FixedVectorType::get(ElTy, Size);
461   if (isa<UndefValue>(V))
462     V = UndefValue::get(RegionTy);
463   else if (isa<ConstantAggregateZero>(V))
464     V = ConstantAggregateZero::get(RegionTy);
465   else {
466     SmallVector<Constant *, 32> Val;
467     for (unsigned i = 0; i != Size; ++i)
468       Val.push_back(V->getAggregateElement(i + StartIdx));
469     V = ConstantVector::get(Val);
470   }
471   return V;
472 }
473 
createRdPredRegionOrConst(Value * Input,unsigned Index,unsigned Size,const Twine & Name,Instruction * InsertBefore,const DebugLoc & DL)474 Value *CMRegion::createRdPredRegionOrConst(Value *Input, unsigned Index,
475     unsigned Size, const Twine &Name, Instruction *InsertBefore,
476     const DebugLoc &DL)
477 {
478   if (auto C = dyn_cast<Constant>(Input))
479     return GetConstantSubvector(C, Index, Size);
480   return createRdPredRegion(Input, Index, Size, Name, InsertBefore, DL);
481 }
482 
483 /***********************************************************************
484  * CMRegion::createWrPredRegion : create wrpredregion instruction
485  *
486  * Enter:   OldVal = vector value to insert subregion into (can be undef)
487  *          Input = subregion value to insert
488  *          Index = start index of subregion
489  *          Name = name for new instruction
490  *          InsertBefore = insert new inst before this point
491  *          DL = DebugLoc to give any new instruction
492  *
493  * Return:  The new wrpredregion instruction
494  *
495  * Unlike createWrRegion, this is a static method in Region, because you pass
496  * the only region parameter (the start index) directly into this method.
497  */
createWrPredRegion(Value * OldVal,Value * Input,unsigned Index,const Twine & Name,Instruction * InsertBefore,const DebugLoc & DL)498 Instruction *CMRegion::createWrPredRegion(Value *OldVal, Value *Input,
499     unsigned Index, const Twine &Name, Instruction *InsertBefore,
500     const DebugLoc &DL)
501 {
502   IntegerType *I32Ty = Type::getInt32Ty(Input->getContext());
503   Value *Args[] = {   // Args to new wrpredregion:
504       OldVal, // original vector
505       Input, // value to write into subregion
506       ConstantInt::get(I32Ty, Index), // start index
507   };
508   Module *M = InsertBefore->getParent()->getParent()->getParent();
509   Function *Decl = getGenXRegionDeclaration(M, GenXIntrinsic::genx_wrpredregion,
510       nullptr, Args);
511   Instruction *NewInst = CallInst::Create(Decl, Args, Name, InsertBefore);
512   NewInst->setDebugLoc(DL);
513   return NewInst;
514 }
515 
516 /***********************************************************************
517  * CMRegion::createWrPredPredRegion : create wrpredpredregion instruction
518  *
519  * Enter:   OldVal = vector value to insert subregion into (can be undef)
520  *          Input = subregion value to insert
521  *          Index = start index of subregion
522  *          Pred = predicate for the write region
523  *          Name = name for new instruction
524  *          InsertBefore = insert new inst before this point
525  *          DL = DebugLoc to give any new instruction
526  *
527  * Return:  The new wrpredpredregion instruction
528  *
529  * Unlike createWrRegion, this is a static method in Region, because you pass
530  * the only region parameter (the start index) directly into this method.
531  */
createWrPredPredRegion(Value * OldVal,Value * Input,unsigned Index,Value * Pred,const Twine & Name,Instruction * InsertBefore,const DebugLoc & DL)532 Instruction *CMRegion::createWrPredPredRegion(Value *OldVal, Value *Input,
533     unsigned Index, Value *Pred, const Twine &Name, Instruction *InsertBefore,
534     const DebugLoc &DL)
535 {
536   Type *Tys[] = { OldVal->getType(), Input->getType() };
537   Function *CalledFunc = GenXIntrinsic::getGenXDeclaration(
538       InsertBefore->getParent()->getParent()->getParent(),
539       GenXIntrinsic::genx_wrpredpredregion, Tys);
540   Value *Args[] = { OldVal, Input,
541       ConstantInt::get(Type::getInt32Ty(InsertBefore->getContext()), Index),
542       Pred };
543   auto NewInst = CallInst::Create(CalledFunc, Args, "", InsertBefore);
544   NewInst->setDebugLoc(DL);
545   return NewInst;
546 }
547 
548 /***********************************************************************
549  * setRegionCalledFunc : for an existing rdregion/wrregion call, modify
550  *      its called function to match its operand types
551  *
552  * This is used in GenXLegalization after modifying a wrregion operand
553  * such that its type changes. The called function then needs to change
554  * because it is decorated with overloaded types.
555  */
setRegionCalledFunc(Instruction * Inst)556 void CMRegion::setRegionCalledFunc(Instruction *Inst)
557 {
558   auto CI = cast<CallInst>(Inst);
559   SmallVector<Value *, 8> Opnds;
560   for (unsigned i = 0, e = CI->getNumArgOperands(); i != e; ++i)
561     Opnds.push_back(CI->getOperand(i));
562   Function *Decl = getGenXRegionDeclaration(
563       Inst->getParent()->getParent()->getParent(),
564       GenXIntrinsic::getGenXIntrinsicID(Inst),
565       Inst->getType(), Opnds);
566   CI->setOperand(CI->getNumArgOperands(), Decl);
567 }
568 
569 /***********************************************************************
570  * getRegionDeclaration : get the function declaration for a region intrinsic
571  *
572  * Enter:   M = Module
573  *          IID = intrinsic ID
574  *          RetTy = return type (can be 0 if return type not overloaded)
575  *          Args = array of operands so we can determine overloaded types
576  *
577  * Return:  the Function
578  */
getGenXRegionDeclaration(Module * M,GenXIntrinsic::ID IID,Type * RetTy,ArrayRef<Value * > Args)579 Function *CMRegion::getGenXRegionDeclaration(Module *M,
580     GenXIntrinsic::ID IID, Type *RetTy, ArrayRef<Value *> Args)
581 {
582   switch (IID) {
583     case GenXIntrinsic::genx_rdregioni:
584     case GenXIntrinsic::genx_rdregionf: {
585       Type *Tys[] = { RetTy, Args[0]->getType(), Args[4]->getType() };
586       return GenXIntrinsic::getGenXDeclaration(M, IID, Tys);
587     }
588     case GenXIntrinsic::genx_wrregioni:
589     case GenXIntrinsic::genx_wrregionf:
590     case GenXIntrinsic::genx_wrconstregion: {
591       Type *Tys[] = { Args[0]->getType(), Args[1]->getType(),
592           Args[5]->getType(), Args[7]->getType() };
593       return GenXIntrinsic::getGenXDeclaration(M, IID, Tys);
594     }
595     case GenXIntrinsic::genx_rdpredregion: {
596       Type *Tys[] = { RetTy, Args[0]->getType() };
597       return GenXIntrinsic::getGenXDeclaration(M, IID, Tys);
598     }
599     case GenXIntrinsic::genx_wrpredregion: {
600       Type *Tys[] = { Args[0]->getType(), Args[1]->getType() };
601       return GenXIntrinsic::getGenXDeclaration(M, IID, Tys);
602     }
603     default:
604       IGC_ASSERT_EXIT_MESSAGE(0, "unrecognized region intrinsic ID");
605   }
606   return nullptr;
607 }
608 
609 /***********************************************************************
610  * getStartIdx : get the LLVM IR Value for the start index of a region
611  *
612  * This is common code used by both createRdRegion and createWrRegion.
613  */
getStartIdx(const Twine & Name,Instruction * InsertBefore,const DebugLoc & DL)614 Value *CMRegion::getStartIdx(const Twine &Name, Instruction *InsertBefore,
615     const DebugLoc &DL)
616 {
617   IntegerType *I16Ty = Type::getInt16Ty(InsertBefore->getContext());
618   if (!Indirect)
619     return ConstantInt::get(I16Ty, Offset);
620   // Deal with indirect (variable index) region.
621   if (auto VT = dyn_cast<IGCLLVM::FixedVectorType>(Indirect->getType())) {
622     if (VT->getNumElements() != NumElements) {
623       // We have a vector indirect and we need to take a subregion of it.
624       CMRegion IdxRegion(Indirect);
625       IdxRegion.getSubregion(IndirectIdx, NumElements / Width);
626       Indirect = IdxRegion.createRdRegion(Indirect,
627           Name + ".multiindirect_idx_subregion", InsertBefore, DL);
628       IndirectIdx = 0;
629     }
630   }
631   Value *Index = Indirect;
632   if (Offset) {
633     Constant *OffsetVal = ConstantInt::get(I16Ty, Offset);
634     if (auto VT = dyn_cast<IGCLLVM::FixedVectorType>(Indirect->getType()))
635       OffsetVal = ConstantVector::getSplat(
636           IGCLLVM::getElementCount(VT->getNumElements()), OffsetVal);
637     auto BO = BinaryOperator::Create(Instruction::Add, Index, OffsetVal,
638         Name + ".indirect_idx_add", InsertBefore);
639     BO->setDebugLoc(DL);
640     Index = BO;
641   }
642   return Index;
643 }
644 
645 /***********************************************************************
646  * isSimilar : compare two regions to see if they have the same region
647  *      parameters other than start offset, also allowing element type to
648  *      be different
649  */
isSimilar(const CMRegion & R2) const650 bool CMRegion::isSimilar(const CMRegion &R2) const
651 {
652   if (ElementBytes == R2.ElementBytes)
653     return isStrictlySimilar(R2);
654   // Change the element type to match, so we can compare the regions.
655   CMRegion R = R2;
656   // FIXME: we need DataLayout here
657   if (!R.changeElementType(ElementTy, nullptr))
658     return false;
659   return isStrictlySimilar(R);
660 }
661 
getAccessBitMap(int MinTrackingOffset) const662 BitVector CMRegion::getAccessBitMap(int MinTrackingOffset) const {
663   // Construct bitmap for a single row
664   BitVector RowBitMap(getRowLength());
665   for (unsigned i = 0; i < Width; i++) {
666     RowBitMap <<= (Stride * ElementBytes);
667     RowBitMap.set(0, ElementBytes);
668   }
669   // Apply row bitmap to a whole region bitmap
670   // exactly NumRows times
671   BitVector BitMap(getLength());
672   unsigned NumRows = NumElements / Width;
673   if (NumRows != 1) {
674     for (unsigned i = 0; i < NumRows; i++) {
675       BitMap <<= (VStride * ElementBytes);
676       BitMap |= RowBitMap;
677     }
678   } else
679     BitMap = std::move(RowBitMap);
680   // Adjust mask according to min tracking
681   // offset for comparison
682   IGC_ASSERT(Offset >= MinTrackingOffset);
683   unsigned Diff = Offset - MinTrackingOffset;
684   if (Diff) {
685     BitMap.resize(BitMap.size() + Diff);
686     BitMap <<= Diff;
687   }
688   return BitMap;
689 }
690 
691 // overlap: Compare two regions to see whether they overlaps each other.
overlap(const CMRegion & R2) const692 bool CMRegion::overlap(const CMRegion &R2) const {
693   // To be conservative, if any of them is indirect, they overlaps.
694   if (Indirect || R2.Indirect)
695     return true;
696   // To be conservative, if different masks are used, they overlaps.
697   if (Mask != R2.Mask)
698     return true;
699   // Check offsets of regions for intersection
700   int MaxOffset = std::max(Offset, R2.Offset);
701   int MinEndOffset = std::min(Offset + getLength(), R2.Offset + R2.getLength());
702   if (MaxOffset > MinEndOffset)
703     return false;
704   // Check overlapping using bit masks
705   int MinOffset = std::min(Offset, R2.Offset);
706   BitVector Mask1 = getAccessBitMap(MinOffset);
707   BitVector Mask2 = R2.getAccessBitMap(MinOffset);
708   // If there are any common bits then these regions overlap
709   return Mask1.anyCommon(Mask2);
710 }
711 
712 /***********************************************************************
713  * CMRegion::isContiguous : test whether a region is contiguous
714  */
isContiguous() const715 bool CMRegion::isContiguous() const {
716   return (Width == 1 || Stride == 1) &&
717          (Width == NumElements || VStride == static_cast<int>(Width));
718 }
719 
720 /***********************************************************************
721  * CMRegion::isWhole : test whether a region covers exactly the whole of the
722  *      given type, allowing for the element type being different
723  */
isWhole(Type * Ty) const724 bool CMRegion::isWhole(Type *Ty) const
725 {
726   return isContiguous() && NumElements * ElementBytes * 8
727       == Ty->getPrimitiveSizeInBits();
728 }
729 
730 /***********************************************************************
731  * evaluateConstantRdRegion : evaluate rdregion with constant input
732  */
evaluateConstantRdRegion(Constant * Input,bool AllowScalar)733 Constant *CMRegion::evaluateConstantRdRegion(Constant *Input, bool AllowScalar)
734 {
735   IGC_ASSERT(!Indirect);
736   if (NumElements != 1)
737     AllowScalar = false;
738   if (Constant *SV = Input->getSplatValue()) {
739     if (AllowScalar)
740       return SV;
741     return ConstantVector::getSplat(IGCLLVM::getElementCount(NumElements), SV);
742   }
743   auto VT = cast<IGCLLVM::FixedVectorType>(Input->getType());
744   SmallVector<Constant *, 8> Values;
745   Constant *Undef = UndefValue::get(
746       AllowScalar ? ElementTy
747                   : IGCLLVM::FixedVectorType::get(ElementTy, NumElements));
748   if (isa<UndefValue>(Input))
749     return Undef;
750   unsigned RowIdx = Offset / ElementBytes;
751   unsigned Idx = RowIdx;
752   unsigned NextRow = Width;
753   for (unsigned i = 0; i != NumElements; ++i) {
754     if (i == NextRow) {
755       RowIdx += VStride;
756       Idx = RowIdx;
757     }
758     if (Idx >= VT->getNumElements())
759       return Undef; // out of range index
760     // Get the element value and push it into Values.
761     if (ConstantDataVector *CDV = dyn_cast<ConstantDataVector>(Input))
762       Values.push_back(CDV->getElementAsConstant(Idx));
763     else {
764       auto CV = cast<ConstantVector>(Input);
765       Values.push_back(CV->getOperand(Idx));
766     }
767     Idx += Stride;
768   }
769   if (AllowScalar)
770     return Values[0];
771   return ConstantVector::get(Values);
772 }
773 
774 /***********************************************************************
775  * evaluateConstantWrRegion : evaluate wrregion with constant inputs
776  */
evaluateConstantWrRegion(Constant * OldVal,Constant * NewVal)777 Constant *CMRegion::evaluateConstantWrRegion(Constant *OldVal, Constant *NewVal)
778 {
779   IGC_ASSERT(!Indirect);
780   SmallVector<Constant *, 8> Vec;
781   for (unsigned i = 0, e = cast<IGCLLVM::FixedVectorType>(OldVal->getType())
782                                ->getNumElements();
783        i != e; ++i)
784     Vec.push_back(OldVal->getAggregateElement(i));
785   unsigned Off = Offset / ElementBytes, Row = Off;
786   auto NewVT = dyn_cast<IGCLLVM::FixedVectorType>(NewVal->getType());
787   unsigned NewNumEls = !NewVT ? 1 : NewVT->getNumElements();
788   for (unsigned i = 0;;) {
789     if (Off >= Vec.size())
790       return UndefValue::get(OldVal->getType()); // out of range
791     Vec[Off] = !NewVT ? NewVal : NewVal->getAggregateElement(i);
792     if (++i == NewNumEls)
793       break;
794     if (i % Width) {
795       Off += Stride;
796       continue;
797     }
798     Row += VStride;
799     Off = Row;
800   }
801   return ConstantVector::get(Vec);
802 }
803 
804 /***********************************************************************
805  * CMRegion::changeElementType : change element type of the region
806  *
807  * Return:  true if succeeded, false if failed (nothing altered)
808  */
changeElementType(Type * NewElementType,const DataLayout * DL)809 bool CMRegion::changeElementType(Type *NewElementType, const DataLayout *DL) {
810   // TODO: enable this assert once out codebase is ready
811   // IGC_ASSERT(DL);
812   IGC_ASSERT(ElementBytes);
813   IGC_ASSERT_MESSAGE(Offset % ElementBytes == 0, "Impossible offset (in bytes) for data type");
814 
815   unsigned NewElementBytes = vc::getTypeSize(NewElementType, DL).inBytes();
816 
817   if (NewElementBytes == ElementBytes) {
818     // No change in element size
819     ElementTy = NewElementType;
820     return true;
821   }
822 
823   unsigned Ratio = NewElementBytes / ElementBytes;
824   if (Ratio >= 1) {
825     IGC_ASSERT_MESSAGE(isPowerOf2_32(Ratio), "Ratio must be pow of 2");
826     // Trying to make the element size bigger.
827     if (Width & (Ratio - 1))
828       return false; // width misaligned
829     if (VStride & (Ratio - 1))
830       return false; // vstride misaligned
831     if (Stride != 1)
832       return false; // rows not contiguous
833     if (Offset & (Ratio - 1))
834       return false;
835     unsigned LogRatio = Log2_32(Ratio);
836     NumElements >>= LogRatio;
837     Width >>= LogRatio;
838     VStride >>= LogRatio;
839     if (Width == 1) {
840       // Width is now 1, so turn it into a 1D region.
841       Stride = VStride;
842       VStride = 0;
843       Width = NumElements;
844     }
845     ElementTy = NewElementType;
846     ElementBytes = NewElementBytes;
847     return true;
848   }
849   // Trying to make the element size smaller.
850   IGC_ASSERT(NewElementBytes);
851   Ratio = ElementBytes / NewElementBytes;
852   IGC_ASSERT_MESSAGE(isPowerOf2_32(Ratio), "Ratio must be pow of 2");
853   unsigned LogRatio = Log2_32(Ratio);
854   if (Stride == 1 || Width == 1) {
855     // Row contiguous.
856     Stride = 1;
857     NumElements <<= LogRatio;
858     Width <<= LogRatio;
859     VStride <<= LogRatio;
860     ElementTy = NewElementType;
861     ElementBytes = NewElementBytes;
862     return true;
863   }
864   if (!is2D()) {
865     // 1D and not contiguous. Turn it into a 2D region.
866     VStride = Stride << LogRatio;
867     Stride = 1;
868     Width = Ratio;
869     NumElements <<= LogRatio;
870     ElementTy = NewElementType;
871     ElementBytes = NewElementBytes;
872     return true;
873   }
874   return false;
875 }
876 
/***********************************************************************
 * CMRegion::append : append region AR to this region
 *
 * Enter:   AR = region to append; it must have a whole number of rows. It
 *               is taken by value because its Offset is advanced row by
 *               row while it is being consumed.
 *
 * Return:  true if succeeded (this region modified)
 *          false if not possible to append (this region in indeterminate state)
 *
 * AR is appended one row at a time. While this region is still 1D, a row of
 * AR either extends the single row or becomes the start of a second row
 * (which fixes VStride). Once this region is 2D, a row of AR must start
 * exactly where the region currently ends and must fit in what is left of
 * the current row.
 *
 * This succeeds even if it leaves this region in an illegal state where
 * it has a non-integral number of rows. After doing a sequence of appends,
 * the caller needs to check that the resulting region is legal by calling
 * isWholeNumRows().
 */
bool CMRegion::append(CMRegion AR)
{
  IGC_ASSERT(AR.isWholeNumRows());
  // Regions with different indirect addresses cannot be combined.
  if (Indirect != AR.Indirect)
    return false;
  IGC_ASSERT(AR.Width);
  unsigned ARNumRows = AR.NumElements / AR.Width;
  // Consider each row of AR separately.
  // After each row, AR.Offset is moved on to the start of AR's next row.
  for (unsigned ARRow = 0; ARRow != ARNumRows;
      ++ARRow, AR.Offset += AR.VStride * AR.ElementBytes) {
    if (NumElements == Width) {
      // This region is currently 1D.
      if (NumElements == 1) {
        // A single element has no meaningful stride yet: derive it from
        // where AR's row starts.
        IGC_ASSERT(ElementBytes);
        Stride = (AR.Offset - Offset) / ElementBytes;
      }
      else if (AR.Width != 1 && Stride != AR.Stride)
        return false; // Mismatched stride.
      // Byte offset at which the next element of this row would be.
      int NextOffset = Offset + Width * Stride * ElementBytes;
      if (AR.Offset == NextOffset) {
        // AR is a continuation of the same single row.
        Width += AR.Width;
        NumElements = Width;
        continue;
      }
      // AR is the start (or whole) of a second row.
      if (AR.Width > Width)
        return false; // AR row is bigger than this row.
      IGC_ASSERT(ElementBytes);
      // The distance between the two row starts defines the vertical stride.
      VStride = (AR.Offset - Offset) / ElementBytes;
      NumElements += AR.Width;
      continue;
    }
    // This region is already 2D.
    IGC_ASSERT(Width);
    // Number of elements already present in the (partial) last row.
    unsigned ExtraBit = NumElements % Width;
    // Byte offset at which the next appended element is expected.
    // NOTE(review): the partial-row term adds ExtraBit elements without
    // scaling by Stride - confirm this is intended when Stride != 1.
    int NextOffset = Offset + ((VStride * (NumElements / Width))
        + ExtraBit) * ElementBytes;
    if (NextOffset != AR.Offset)
      return false; // Mismatched next offset.
    if (AR.Width > Width - ExtraBit)
      return false; // Too much to fill whole row, or remainder of row after
                    //   existing extra bit.
    if (AR.Width != 1 && AR.Stride != Stride)
      return false; // Mismatched stride.
    NumElements += AR.Width;
  }
  return true;
}
937 
938 /***********************************************************************
939  * Region debug dump/print
940  */
941 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
dump() const942 void CMRegion::dump() const
943 {
944   errs() << *this << "\n";
945 }
946 #endif
947 
print(raw_ostream & OS) const948 void CMRegion::print(raw_ostream &OS) const
949 {
950   OS << *IGCLLVM::FixedVectorType::get(ElementTy, NumElements) << " <"
951      << VStride << ";" << Width << "," << Stride << ">(";
952   if (Indirect) {
953     OS << Indirect->getName();
954     if (auto VT = dyn_cast<IGCLLVM::FixedVectorType>(Indirect->getType()))
955       OS << "<" << VT->getNumElements() << ">(" << IndirectIdx << ")";
956     OS << " + ";
957   }
958   OS << Offset << ")";
959   if (Indirect && ParentWidth)
960     OS << " {parentwidth=" << ParentWidth << "}";
961   if (Mask)
962     OS << " {mask=" << *Mask << "}";
963 }
964 
965