1 //===- RISCVGatherScatterLowering.cpp - Gather/Scatter lowering -----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This pass custom lowers llvm.gather and llvm.scatter instructions to
10 // RISCV intrinsics.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "RISCV.h"
15 #include "RISCVTargetMachine.h"
16 #include "llvm/Analysis/LoopInfo.h"
17 #include "llvm/Analysis/ValueTracking.h"
18 #include "llvm/Analysis/VectorUtils.h"
19 #include "llvm/CodeGen/TargetPassConfig.h"
20 #include "llvm/IR/GetElementPtrTypeIterator.h"
21 #include "llvm/IR/IRBuilder.h"
22 #include "llvm/IR/IntrinsicInst.h"
23 #include "llvm/IR/IntrinsicsRISCV.h"
24 #include "llvm/Transforms/Utils/Local.h"
25 
26 using namespace llvm;
27 
28 #define DEBUG_TYPE "riscv-gather-scatter-lowering"
29 
30 namespace {
31 
/// Pass that custom lowers llvm.masked.gather / llvm.masked.scatter calls
/// whose index vectors form a strided recurrence into RISC-V masked strided
/// load/store intrinsics.
class RISCVGatherScatterLowering : public FunctionPass {
  const RISCVSubtarget *ST = nullptr;
  const RISCVTargetLowering *TLI = nullptr;
  LoopInfo *LI = nullptr;
  const DataLayout *DL = nullptr;

  // Vector phis whose recurrences were scalarized; they may be dead after
  // their gather/scatter users are rewritten and are cleaned up at the end
  // of runOnFunction.
  SmallVector<WeakTrackingVH> MaybeDeadPHIs;

public:
  static char ID; // Pass identification, replacement for typeid

  RISCVGatherScatterLowering() : FunctionPass(ID) {}

  bool runOnFunction(Function &F) override;

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.setPreservesCFG();
    AU.addRequired<TargetPassConfig>();
    AU.addRequired<LoopInfoWrapperPass>();
  }

  StringRef getPassName() const override {
    return "RISCV gather/scatter lowering";
  }

private:
  // Returns true if DataType with the constant alignment AlignOp is
  // supported by the RVV strided memory intrinsics.
  bool isLegalTypeAndAlignment(Type *DataType, Value *AlignOp);

  // Attempts to replace the gather/scatter II (addressing through Ptr) with
  // a strided load/store; returns true and erases II on success.
  bool tryCreateStridedLoadStore(IntrinsicInst *II, Type *DataType, Value *Ptr,
                                 Value *AlignOp);

  // Computes a scalar (base pointer, byte stride) pair for GEP, or
  // {nullptr, nullptr} if the address is not a recognizable strided access.
  std::pair<Value *, Value *> determineBaseAndStride(GetElementPtrInst *GEP,
                                                     IRBuilder<> &Builder);

  // Walks the use-def chain of Index looking for a strided recurrence and
  // builds an equivalent scalar recurrence (BasePtr/Inc) while unwinding.
  bool matchStridedRecurrence(Value *Index, Loop *L, Value *&Stride,
                              PHINode *&BasePtr, BinaryOperator *&Inc,
                              IRBuilder<> &Builder);
};
70 
71 } // end anonymous namespace
72 
73 char RISCVGatherScatterLowering::ID = 0;
74 
75 INITIALIZE_PASS(RISCVGatherScatterLowering, DEBUG_TYPE,
76                 "RISCV gather/scatter lowering pass", false, false)
77 
/// Factory used by the RISC-V target to add this pass to the codegen pipeline.
FunctionPass *llvm::createRISCVGatherScatterLoweringPass() {
  return new RISCVGatherScatterLowering();
}
81 
82 bool RISCVGatherScatterLowering::isLegalTypeAndAlignment(Type *DataType,
83                                                          Value *AlignOp) {
84   Type *ScalarType = DataType->getScalarType();
85   if (!TLI->isLegalElementTypeForRVV(ScalarType))
86     return false;
87 
88   MaybeAlign MA = cast<ConstantInt>(AlignOp)->getMaybeAlignValue();
89   if (MA && MA->value() < DL->getTypeStoreSize(ScalarType).getFixedSize())
90     return false;
91 
92   // FIXME: Let the backend type legalize by splitting/widening?
93   EVT DataVT = TLI->getValueType(*DL, DataType);
94   if (!TLI->isTypeLegal(DataVT))
95     return false;
96 
97   return true;
98 }
99 
100 // TODO: Should we consider the mask when looking for a stride?
101 static std::pair<Value *, Value *> matchStridedConstant(Constant *StartC) {
102   unsigned NumElts = cast<FixedVectorType>(StartC->getType())->getNumElements();
103 
104   // Check that the start value is a strided constant.
105   auto *StartVal =
106       dyn_cast_or_null<ConstantInt>(StartC->getAggregateElement((unsigned)0));
107   if (!StartVal)
108     return std::make_pair(nullptr, nullptr);
109   APInt StrideVal(StartVal->getValue().getBitWidth(), 0);
110   ConstantInt *Prev = StartVal;
111   for (unsigned i = 1; i != NumElts; ++i) {
112     auto *C = dyn_cast_or_null<ConstantInt>(StartC->getAggregateElement(i));
113     if (!C)
114       return std::make_pair(nullptr, nullptr);
115 
116     APInt LocalStride = C->getValue() - Prev->getValue();
117     if (i == 1)
118       StrideVal = LocalStride;
119     else if (StrideVal != LocalStride)
120       return std::make_pair(nullptr, nullptr);
121 
122     Prev = C;
123   }
124 
125   Value *Stride = ConstantInt::get(StartVal->getType(), StrideVal);
126 
127   return std::make_pair(StartVal, Stride);
128 }
129 
130 static std::pair<Value *, Value *> matchStridedStart(Value *Start,
131                                                      IRBuilder<> &Builder) {
132   // Base case, start is a strided constant.
133   auto *StartC = dyn_cast<Constant>(Start);
134   if (StartC)
135     return matchStridedConstant(StartC);
136 
137   // Not a constant, maybe it's a strided constant with a splat added to it.
138   auto *BO = dyn_cast<BinaryOperator>(Start);
139   if (!BO || BO->getOpcode() != Instruction::Add)
140     return std::make_pair(nullptr, nullptr);
141 
142   // Look for an operand that is splatted.
143   unsigned OtherIndex = 1;
144   Value *Splat = getSplatValue(BO->getOperand(0));
145   if (!Splat) {
146     Splat = getSplatValue(BO->getOperand(1));
147     OtherIndex = 0;
148   }
149   if (!Splat)
150     return std::make_pair(nullptr, nullptr);
151 
152   Value *Stride;
153   std::tie(Start, Stride) = matchStridedStart(BO->getOperand(OtherIndex),
154                                               Builder);
155   if (!Start)
156     return std::make_pair(nullptr, nullptr);
157 
158   // Add the splat value to the start.
159   Builder.SetInsertPoint(BO);
160   Builder.SetCurrentDebugLocation(DebugLoc());
161   Start = Builder.CreateAdd(Start, Splat);
162   return std::make_pair(Start, Stride);
163 }
164 
// Recursively, walk about the use-def chain until we find a Phi with a strided
// start value. Build and update a scalar recurrence as we unwind the recursion.
// We also update the Stride as we unwind. Our goal is to move all of the
// arithmetic out of the loop.
//
// On success, \p Stride holds the scalar stride of the index, \p BasePtr is a
// newly created scalar phi mirroring the vector recurrence, and \p Inc is the
// scalar add feeding it back. Note this mutates the IR as a side effect:
// scalar phis/adds are created and loop-invariant adjustments are emitted
// into the start block via \p Builder.
bool RISCVGatherScatterLowering::matchStridedRecurrence(Value *Index, Loop *L,
                                                        Value *&Stride,
                                                        PHINode *&BasePtr,
                                                        BinaryOperator *&Inc,
                                                        IRBuilder<> &Builder) {
  // Our base case is a Phi.
  if (auto *Phi = dyn_cast<PHINode>(Index)) {
    // A phi node we want to perform this function on should be from the
    // loop header.
    if (Phi->getParent() != L->getHeader())
      return false;

    // The phi must be a simple add recurrence: Phi = phi(Start, Inc) with
    // Inc = add Phi, Step.
    Value *Step, *Start;
    if (!matchSimpleRecurrence(Phi, Inc, Start, Step) ||
        Inc->getOpcode() != Instruction::Add)
      return false;
    assert(Phi->getNumIncomingValues() == 2 && "Expected 2 operand phi.");
    unsigned IncrementingBlock = Phi->getIncomingValue(0) == Inc ? 0 : 1;
    assert(Phi->getIncomingValue(IncrementingBlock) == Inc &&
           "Expected one operand of phi to be Inc");

    // Only proceed if the step is loop invariant.
    if (!L->isLoopInvariant(Step))
      return false;

    // Step should be a splat.
    Step = getSplatValue(Step);
    if (!Step)
      return false;

    // The start value must itself be strided so we can scalarize it.
    std::tie(Start, Stride) = matchStridedStart(Start, Builder);
    if (!Start)
      return false;
    assert(Stride != nullptr);

    // Build scalar phi and increment.
    BasePtr =
        PHINode::Create(Start->getType(), 2, Phi->getName() + ".scalar", Phi);
    Inc = BinaryOperator::CreateAdd(BasePtr, Step, Inc->getName() + ".scalar",
                                    Inc);
    BasePtr->addIncoming(Start, Phi->getIncomingBlock(1 - IncrementingBlock));
    BasePtr->addIncoming(Inc, Phi->getIncomingBlock(IncrementingBlock));

    // Note that this Phi might be eligible for removal.
    MaybeDeadPHIs.push_back(Phi);
    return true;
  }

  // Otherwise look for binary operator.
  auto *BO = dyn_cast<BinaryOperator>(Index);
  if (!BO)
    return false;

  // Only these opcodes distribute over a strided sequence in a way we can
  // undo below.
  if (BO->getOpcode() != Instruction::Add &&
      BO->getOpcode() != Instruction::Or &&
      BO->getOpcode() != Instruction::Mul &&
      BO->getOpcode() != Instruction::Shl)
    return false;

  // Only support shift by constant.
  if (BO->getOpcode() == Instruction::Shl && !isa<Constant>(BO->getOperand(1)))
    return false;

  // We need to be able to treat Or as Add.
  if (BO->getOpcode() == Instruction::Or &&
      !haveNoCommonBitsSet(BO->getOperand(0), BO->getOperand(1), *DL))
    return false;

  // We should have one operand in the loop and one splat.
  Value *OtherOp;
  if (isa<Instruction>(BO->getOperand(0)) &&
      L->contains(cast<Instruction>(BO->getOperand(0)))) {
    Index = cast<Instruction>(BO->getOperand(0));
    OtherOp = BO->getOperand(1);
  } else if (isa<Instruction>(BO->getOperand(1)) &&
             L->contains(cast<Instruction>(BO->getOperand(1)))) {
    Index = cast<Instruction>(BO->getOperand(1));
    OtherOp = BO->getOperand(0);
  } else {
    return false;
  }

  // Make sure other op is loop invariant.
  if (!L->isLoopInvariant(OtherOp))
    return false;

  // Make sure we have a splat.
  Value *SplatOp = getSplatValue(OtherOp);
  if (!SplatOp)
    return false;

  // Recurse up the use-def chain.
  if (!matchStridedRecurrence(Index, L, Stride, BasePtr, Inc, Builder))
    return false;

  // Locate the Step and Start values from the recurrence.
  unsigned StepIndex = Inc->getOperand(0) == BasePtr ? 1 : 0;
  unsigned StartBlock = BasePtr->getOperand(0) == Inc ? 1 : 0;
  Value *Step = Inc->getOperand(StepIndex);
  Value *Start = BasePtr->getOperand(StartBlock);

  // We need to adjust the start value in the preheader.
  Builder.SetInsertPoint(
      BasePtr->getIncomingBlock(StartBlock)->getTerminator());
  Builder.SetCurrentDebugLocation(DebugLoc());

  // Apply BO's splat operand to the scalar recurrence built by the recursion.
  switch (BO->getOpcode()) {
  default:
    llvm_unreachable("Unexpected opcode!");
  case Instruction::Add:
  case Instruction::Or: {
    // An add only affects the start value. It's ok to do this for Or because
    // we already checked that there are no common set bits.

    // If the start value is Zero, just take the SplatOp.
    if (isa<ConstantInt>(Start) && cast<ConstantInt>(Start)->isZero())
      Start = SplatOp;
    else
      Start = Builder.CreateAdd(Start, SplatOp, "start");
    BasePtr->setIncomingValue(StartBlock, Start);
    break;
  }
  case Instruction::Mul: {
    // A multiply scales start, step, and stride.
    // If the start is zero we don't need to multiply.
    if (!isa<ConstantInt>(Start) || !cast<ConstantInt>(Start)->isZero())
      Start = Builder.CreateMul(Start, SplatOp, "start");

    Step = Builder.CreateMul(Step, SplatOp, "step");

    // If the Stride is 1 just take the SplatOpt.
    if (isa<ConstantInt>(Stride) && cast<ConstantInt>(Stride)->isOne())
      Stride = SplatOp;
    else
      Stride = Builder.CreateMul(Stride, SplatOp, "stride");
    Inc->setOperand(StepIndex, Step);
    BasePtr->setIncomingValue(StartBlock, Start);
    break;
  }
  case Instruction::Shl: {
    // A shift scales start, step, and stride like a multiply by a power of 2.
    // If the start is zero we don't need to shift.
    if (!isa<ConstantInt>(Start) || !cast<ConstantInt>(Start)->isZero())
      Start = Builder.CreateShl(Start, SplatOp, "start");
    Step = Builder.CreateShl(Step, SplatOp, "step");
    Stride = Builder.CreateShl(Stride, SplatOp, "stride");
    Inc->setOperand(StepIndex, Step);
    BasePtr->setIncomingValue(StartBlock, Start);
    break;
  }
  }

  return true;
}
321 
// Given a vector GEP feeding a gather/scatter, try to express the address as
// a scalar i8* base pointer plus a scalar byte stride. Returns
// {nullptr, nullptr} on failure. On success the IR has been rewritten: the
// vector index recurrence is scalarized and a scalar GEP is built at the
// original GEP's position.
std::pair<Value *, Value *>
RISCVGatherScatterLowering::determineBaseAndStride(GetElementPtrInst *GEP,
                                                   IRBuilder<> &Builder) {

  SmallVector<Value *, 2> Ops(GEP->operands());

  // Base pointer needs to be a scalar.
  if (Ops[0]->getType()->isVectorTy())
    return std::make_pair(nullptr, nullptr);

  // Make sure we're in a loop and it is in loop simplify form.
  Loop *L = LI->getLoopFor(GEP->getParent());
  if (!L || !L->isLoopSimplifyForm())
    return std::make_pair(nullptr, nullptr);

  Optional<unsigned> VecOperand;
  unsigned TypeScale = 0;

  // Look for a vector operand and scale. Exactly one GEP index may be a
  // vector; its indexed type's alloc size becomes the stride scale.
  gep_type_iterator GTI = gep_type_begin(GEP);
  for (unsigned i = 1, e = GEP->getNumOperands(); i != e; ++i, ++GTI) {
    if (!Ops[i]->getType()->isVectorTy())
      continue;

    // Reject GEPs with more than one vector index.
    if (VecOperand)
      return std::make_pair(nullptr, nullptr);

    VecOperand = i;

    // A scalable element size can't be folded into a fixed stride.
    TypeSize TS = DL->getTypeAllocSize(GTI.getIndexedType());
    if (TS.isScalable())
      return std::make_pair(nullptr, nullptr);

    TypeScale = TS.getFixedSize();
  }

  // We need to find a vector index to simplify.
  if (!VecOperand)
    return std::make_pair(nullptr, nullptr);

  // We can't extract the stride if the arithmetic is done at a different size
  // than the pointer type. Adding the stride later may not wrap correctly.
  // Technically we could handle wider indices, but I don't expect that in
  // practice.
  Value *VecIndex = Ops[*VecOperand];
  Type *VecIntPtrTy = DL->getIntPtrType(GEP->getType());
  if (VecIndex->getType() != VecIntPtrTy)
    return std::make_pair(nullptr, nullptr);

  // Scalarize the index recurrence; this mutates the IR on success.
  Value *Stride;
  BinaryOperator *Inc;
  PHINode *BasePhi;
  if (!matchStridedRecurrence(VecIndex, L, Stride, BasePhi, Inc, Builder))
    return std::make_pair(nullptr, nullptr);

  assert(BasePhi->getNumIncomingValues() == 2 && "Expected 2 operand phi.");
  unsigned IncrementingBlock = BasePhi->getOperand(0) == Inc ? 0 : 1;
  assert(BasePhi->getIncomingValue(IncrementingBlock) == Inc &&
         "Expected one operand of phi to be Inc");

  Builder.SetInsertPoint(GEP);

  // Replace the vector index with the scalar phi and build a scalar GEP.
  Ops[*VecOperand] = BasePhi;
  Type *SourceTy = GEP->getSourceElementType();
  Value *BasePtr =
      Builder.CreateGEP(SourceTy, Ops[0], makeArrayRef(Ops).drop_front());

  // Cast the GEP to an i8*.
  LLVMContext &Ctx = GEP->getContext();
  Type *I8PtrTy =
      Type::getInt8PtrTy(Ctx, GEP->getType()->getPointerAddressSpace());
  if (BasePtr->getType() != I8PtrTy)
    BasePtr = Builder.CreatePointerCast(BasePtr, I8PtrTy);

  // Final adjustments to stride should go in the start block.
  Builder.SetInsertPoint(
      BasePhi->getIncomingBlock(1 - IncrementingBlock)->getTerminator());

  // Convert stride to pointer size if needed.
  Type *IntPtrTy = DL->getIntPtrType(BasePtr->getType());
  assert(Stride->getType() == IntPtrTy && "Unexpected type");

  // Scale the stride by the size of the indexed type.
  if (TypeScale != 1)
    Stride = Builder.CreateMul(Stride, ConstantInt::get(IntPtrTy, TypeScale));

  return std::make_pair(BasePtr, Stride);
}
411 
// Try to rewrite the masked gather/scatter II (accessed type DataType,
// pointer operand Ptr, constant alignment operand AlignOp) into a RISC-V
// masked strided load/store intrinsic. Returns true and erases II on success.
bool RISCVGatherScatterLowering::tryCreateStridedLoadStore(IntrinsicInst *II,
                                                           Type *DataType,
                                                           Value *Ptr,
                                                           Value *AlignOp) {
  // Make sure the operation will be supported by the backend.
  if (!isLegalTypeAndAlignment(DataType, AlignOp))
    return false;

  // Pointer should be a GEP.
  auto *GEP = dyn_cast<GetElementPtrInst>(Ptr);
  if (!GEP)
    return false;

  IRBuilder<> Builder(GEP);

  // Scalarize the address into base + stride; this mutates the IR on success.
  Value *BasePtr, *Stride;
  std::tie(BasePtr, Stride) = determineBaseAndStride(GEP, Builder);
  if (!BasePtr)
    return false;
  assert(Stride != nullptr);

  Builder.SetInsertPoint(II);

  // Emit the strided intrinsic. For a gather, arg 3 is the passthru and
  // arg 2 the mask; for a scatter, arg 0 is the stored value and arg 3 the
  // mask (per the llvm.masked.gather/scatter signatures).
  CallInst *Call;
  if (II->getIntrinsicID() == Intrinsic::masked_gather)
    Call = Builder.CreateIntrinsic(
        Intrinsic::riscv_masked_strided_load,
        {DataType, BasePtr->getType(), Stride->getType()},
        {II->getArgOperand(3), BasePtr, Stride, II->getArgOperand(2)});
  else
    Call = Builder.CreateIntrinsic(
        Intrinsic::riscv_masked_strided_store,
        {DataType, BasePtr->getType(), Stride->getType()},
        {II->getArgOperand(0), BasePtr, Stride, II->getArgOperand(3)});

  Call->takeName(II);
  II->replaceAllUsesWith(Call);
  II->eraseFromParent();

  // The vector GEP may now be dead; clean it and its operand chain up.
  if (GEP->use_empty())
    RecursivelyDeleteTriviallyDeadInstructions(GEP);

  return true;
}
456 
457 bool RISCVGatherScatterLowering::runOnFunction(Function &F) {
458   if (skipFunction(F))
459     return false;
460 
461   auto &TPC = getAnalysis<TargetPassConfig>();
462   auto &TM = TPC.getTM<RISCVTargetMachine>();
463   ST = &TM.getSubtarget<RISCVSubtarget>(F);
464   if (!ST->hasVInstructions() || !ST->useRVVForFixedLengthVectors())
465     return false;
466 
467   TLI = ST->getTargetLowering();
468   DL = &F.getParent()->getDataLayout();
469   LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
470 
471   SmallVector<IntrinsicInst *, 4> Gathers;
472   SmallVector<IntrinsicInst *, 4> Scatters;
473 
474   bool Changed = false;
475 
476   for (BasicBlock &BB : F) {
477     for (Instruction &I : BB) {
478       IntrinsicInst *II = dyn_cast<IntrinsicInst>(&I);
479       if (II && II->getIntrinsicID() == Intrinsic::masked_gather &&
480           isa<FixedVectorType>(II->getType())) {
481         Gathers.push_back(II);
482       } else if (II && II->getIntrinsicID() == Intrinsic::masked_scatter &&
483                  isa<FixedVectorType>(II->getArgOperand(0)->getType())) {
484         Scatters.push_back(II);
485       }
486     }
487   }
488 
489   // Rewrite gather/scatter to form strided load/store if possible.
490   for (auto *II : Gathers)
491     Changed |= tryCreateStridedLoadStore(
492         II, II->getType(), II->getArgOperand(0), II->getArgOperand(1));
493   for (auto *II : Scatters)
494     Changed |=
495         tryCreateStridedLoadStore(II, II->getArgOperand(0)->getType(),
496                                   II->getArgOperand(1), II->getArgOperand(2));
497 
498   // Remove any dead phis.
499   while (!MaybeDeadPHIs.empty()) {
500     if (auto *Phi = dyn_cast_or_null<PHINode>(MaybeDeadPHIs.pop_back_val()))
501       RecursivelyDeleteDeadPHINode(Phi);
502   }
503 
504   return Changed;
505 }
506