//===- MVEGatherScatterLowering.cpp - Gather/Scatter lowering -------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
/// This pass custom lowers llvm.gather and llvm.scatter instructions to
/// arm.mve.gather and arm.mve.scatter intrinsics, optimising the code to
/// produce a better final result as we go.
//
//===----------------------------------------------------------------------===//

#include "ARM.h"
#include "ARMBaseInstrInfo.h"
#include "ARMSubtarget.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/CodeGen/TargetLowering.h"
#include "llvm/CodeGen/TargetPassConfig.h"
#include "llvm/CodeGen/TargetSubtargetInfo.h"
#include "llvm/InitializePasses.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Constant.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/InstrTypes.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/IntrinsicsARM.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/IR/Type.h"
#include "llvm/IR/Value.h"
#include "llvm/Pass.h"
#include "llvm/Support/Casting.h"
#include "llvm/Transforms/Utils/Local.h"
#include <algorithm>
#include <cassert>

using namespace llvm;

#define DEBUG_TYPE "mve-gather-scatter-lowering"

cl::opt<bool> EnableMaskedGatherScatters(
    "enable-arm-maskedgatscat", cl::Hidden, cl::init(false),
    cl::desc("Enable the generation of masked gathers and scatters"));

namespace {

class MVEGatherScatterLowering : public FunctionPass {
public:
  static char ID; // Pass identification, replacement for typeid

  explicit MVEGatherScatterLowering() : FunctionPass(ID) {
    initializeMVEGatherScatterLoweringPass(*PassRegistry::getPassRegistry());
  }

  bool runOnFunction(Function &F) override;

  StringRef getPassName() const override {
    return "MVE gather/scatter lowering";
  }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.setPreservesCFG();
    AU.addRequired<TargetPassConfig>();
    AU.addRequired<LoopInfoWrapperPass>();
    FunctionPass::getAnalysisUsage(AU);
  }

private:
  LoopInfo *LI = nullptr;

  // Check this is a valid gather with correct alignment
  bool isLegalTypeAndAlignment(unsigned NumElements, unsigned ElemSize,
                               Align Alignment);
  // Check whether Ptr is hidden behind a bitcast and look through it
  void lookThroughBitcast(Value *&Ptr);
  // Check for a getelementptr and deduce base and offsets from it, on success
  // returning the base directly and the offsets indirectly using the Offsets
  // argument
  Value *checkGEP(Value *&Offsets, Type *Ty, GetElementPtrInst *GEP,
                  IRBuilder<> &Builder);
  // Compute the scale of this gather/scatter instruction
  int computeScale(unsigned GEPElemSize, unsigned MemoryElemSize);
  // If the value is a constant, or derived from constants via additions
  // and multiplications, return its numeric value
  Optional<int64_t> getIfConst(const Value *V);
  // If Inst is an add instruction, check whether one summand is a
  // constant. If so, scale this constant and return it together with
  // the other summand.
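  // For example (hypothetical values, purely for illustration): given an
  // Inst of `%iv + 16` and a TypeScale of 2, this would return {%iv, 64},
  // i.e. the constant summand pre-scaled to a byte offset.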
  std::pair<Value *, int64_t> getVarAndConst(Value *Inst, int TypeScale);

  Value *lowerGather(IntrinsicInst *I);
  // Create a gather from a base + vector of offsets
  Value *tryCreateMaskedGatherOffset(IntrinsicInst *I, Value *Ptr,
                                     Instruction *&Root, IRBuilder<> &Builder);
  // Create a gather from a vector of pointers
  Value *tryCreateMaskedGatherBase(IntrinsicInst *I, Value *Ptr,
                                   IRBuilder<> &Builder,
                                   int64_t Increment = 0);
  // Create an incrementing gather from a vector of pointers
  Value *tryCreateMaskedGatherBaseWB(IntrinsicInst *I, Value *Ptr,
                                     IRBuilder<> &Builder,
                                     int64_t Increment = 0);

  Value *lowerScatter(IntrinsicInst *I);
  // Create a scatter to a base + vector of offsets
  Value *tryCreateMaskedScatterOffset(IntrinsicInst *I, Value *Offsets,
                                      IRBuilder<> &Builder);
  // Create a scatter to a vector of pointers
  Value *tryCreateMaskedScatterBase(IntrinsicInst *I, Value *Ptr,
                                    IRBuilder<> &Builder,
                                    int64_t Increment = 0);
  // Create an incrementing scatter from a vector of pointers
  Value *tryCreateMaskedScatterBaseWB(IntrinsicInst *I, Value *Ptr,
                                      IRBuilder<> &Builder,
                                      int64_t Increment = 0);

  // QI gathers and scatters can increment their offsets on their own if
  // the increment is a constant value (digit)
  Value *tryCreateIncrementingGatScat(IntrinsicInst *I, Value *BasePtr,
                                      Value *Ptr, GetElementPtrInst *GEP,
                                      IRBuilder<> &Builder);
  // QI gathers/scatters can increment their offsets on their own if the
  // increment is a constant value (digit) - this creates a writeback QI
  // gather/scatter
  Value *tryCreateIncrementingWBGatScat(IntrinsicInst *I, Value *BasePtr,
                                        Value *Ptr, unsigned TypeScale,
                                        IRBuilder<> &Builder);

  // Check whether these offsets could be moved out of the loop they're in
  bool optimiseOffsets(Value *Offsets, BasicBlock *BB, LoopInfo *LI);
  // Pushes the given add out of the loop
  void pushOutAdd(PHINode *&Phi, Value *OffsSecondOperand,
                  unsigned StartIndex);
  // Pushes the given mul out of the loop
  void pushOutMul(PHINode *&Phi, Value *IncrementPerRound,
                  Value *OffsSecondOperand, unsigned LoopIncrement,
                  IRBuilder<> &Builder);
};

} // end anonymous namespace

char MVEGatherScatterLowering::ID = 0;

INITIALIZE_PASS(MVEGatherScatterLowering, DEBUG_TYPE,
                "MVE gather/scatter lowering pass", false, false)

Pass *llvm::createMVEGatherScatterLoweringPass() {
  return new MVEGatherScatterLowering();
}

bool MVEGatherScatterLowering::isLegalTypeAndAlignment(unsigned NumElements,
                                                       unsigned ElemSize,
                                                       Align Alignment) {
  if (((NumElements == 4 &&
        (ElemSize == 32 || ElemSize == 16 || ElemSize == 8)) ||
       (NumElements == 8 && (ElemSize == 16 || ElemSize == 8)) ||
       (NumElements == 16 && ElemSize == 8)) &&
      Alignment >= ElemSize / 8)
    return true;
  LLVM_DEBUG(dbgs() << "masked gathers/scatters: instruction does not have "
                    << "valid alignment or vector type\n");
  return false;
}

Value *MVEGatherScatterLowering::checkGEP(Value *&Offsets, Type *Ty,
                                          GetElementPtrInst *GEP,
                                          IRBuilder<> &Builder) {
  if (!GEP) {
    LLVM_DEBUG(
        dbgs() << "masked gathers/scatters: no getelementpointer found\n");
    return nullptr;
  }
  LLVM_DEBUG(dbgs() << "masked gathers/scatters: getelementpointer found."
                    << " Looking at intrinsic for base + vector of offsets\n");
  Value *GEPPtr = GEP->getPointerOperand();
  if (GEPPtr->getType()->isVectorTy()) {
    return nullptr;
  }
  if (GEP->getNumOperands() != 2) {
    LLVM_DEBUG(dbgs() << "masked gathers/scatters: getelementptr with too many"
                      << " operands. Expanding.\n");
    return nullptr;
  }
  Offsets = GEP->getOperand(1);
  // Paranoid check whether the number of parallel lanes is the same
  assert(cast<FixedVectorType>(Ty)->getNumElements() ==
         cast<FixedVectorType>(Offsets->getType())->getNumElements());
  // Only <N x i32> offsets can be integrated into an arm gather, any smaller
  // type would have to be sign extended by the gep - and arm gathers can only
  // zero extend. Additionally, the offsets do have to originate from a zext of
  // a vector with element types smaller or equal the type of the gather we're
  // looking at
  if (Offsets->getType()->getScalarSizeInBits() != 32)
    return nullptr;
  if (ZExtInst *ZextOffs = dyn_cast<ZExtInst>(Offsets))
    Offsets = ZextOffs->getOperand(0);
  else if (!(cast<FixedVectorType>(Offsets->getType())->getNumElements() ==
                 4 &&
             Offsets->getType()->getScalarSizeInBits() == 32))
    return nullptr;

  if (Ty != Offsets->getType()) {
    if ((Ty->getScalarSizeInBits() <
         Offsets->getType()->getScalarSizeInBits())) {
      LLVM_DEBUG(dbgs() << "masked gathers/scatters: no correct offset type."
                        << " Can't create intrinsic.\n");
      return nullptr;
    } else {
      Offsets = Builder.CreateZExt(
          Offsets, VectorType::getInteger(cast<VectorType>(Ty)));
    }
  }
  // If none of the checks failed, return the gep's base pointer
  LLVM_DEBUG(dbgs() << "masked gathers/scatters: found correct offsets\n");
  return GEPPtr;
}

void MVEGatherScatterLowering::lookThroughBitcast(Value *&Ptr) {
  // Look through bitcast instruction if #elements is the same
  if (auto *BitCast = dyn_cast<BitCastInst>(Ptr)) {
    auto *BCTy = cast<FixedVectorType>(BitCast->getType());
    auto *BCSrcTy = cast<FixedVectorType>(BitCast->getOperand(0)->getType());
    if (BCTy->getNumElements() == BCSrcTy->getNumElements()) {
      LLVM_DEBUG(
          dbgs() << "masked gathers/scatters: looking through bitcast\n");
      Ptr = BitCast->getOperand(0);
    }
  }
}

int MVEGatherScatterLowering::computeScale(unsigned GEPElemSize,
                                           unsigned MemoryElemSize) {
  // This can be a 32bit load/store scaled by 4, a 16bit load/store scaled by
  // 2, or an 8bit, 16bit or 32bit load/store scaled by 1
  if (GEPElemSize == 32 && MemoryElemSize == 32)
    return 2;
  else if (GEPElemSize == 16 && MemoryElemSize == 16)
    return 1;
  else if (GEPElemSize == 8)
    return 0;
  LLVM_DEBUG(dbgs() << "masked gathers/scatters: incorrect scale."
                    << " Can't create intrinsic\n");
  return -1;
}

Optional<int64_t> MVEGatherScatterLowering::getIfConst(const Value *V) {
  const Constant *C = dyn_cast<Constant>(V);
  if (C != nullptr)
    return Optional<int64_t>{C->getUniqueInteger().getSExtValue()};
  if (!isa<Instruction>(V))
    return Optional<int64_t>{};

  const Instruction *I = cast<Instruction>(V);
  if (I->getOpcode() == Instruction::Add ||
      I->getOpcode() == Instruction::Mul) {
    Optional<int64_t> Op0 = getIfConst(I->getOperand(0));
    Optional<int64_t> Op1 = getIfConst(I->getOperand(1));
    if (!Op0 || !Op1)
      return Optional<int64_t>{};
    if (I->getOpcode() == Instruction::Add)
      return Optional<int64_t>{Op0.getValue() + Op1.getValue()};
    if (I->getOpcode() == Instruction::Mul)
      return Optional<int64_t>{Op0.getValue() * Op1.getValue()};
  }
  return Optional<int64_t>{};
}

std::pair<Value *, int64_t>
MVEGatherScatterLowering::getVarAndConst(Value *Inst, int TypeScale) {
  std::pair<Value *, int64_t> ReturnFalse =
      std::pair<Value *, int64_t>(nullptr, 0);
  // At this point, the instruction we're looking at must be an add or we
  // bail out
  Instruction *Add = dyn_cast<Instruction>(Inst);
  if (Add == nullptr || Add->getOpcode() != Instruction::Add)
    return ReturnFalse;

  Value *Summand;
  Optional<int64_t> Const;
  // Find out which operand the value that is increased is
  if ((Const = getIfConst(Add->getOperand(0))))
    Summand = Add->getOperand(1);
  else if ((Const = getIfConst(Add->getOperand(1))))
    Summand = Add->getOperand(0);
  else
    return ReturnFalse;

  // Check that the constant is small enough for an incrementing gather
  int64_t Immediate = Const.getValue() << TypeScale;
  if (Immediate > 512 || Immediate < -512 || Immediate % 4 != 0)
    return ReturnFalse;

  return std::pair<Value *, int64_t>(Summand, Immediate);
}

Value *MVEGatherScatterLowering::lowerGather(IntrinsicInst *I) {
  using namespace PatternMatch;
  LLVM_DEBUG(dbgs() << "masked gathers: checking transform preconditions\n");

  // @llvm.masked.gather.*(Ptrs, alignment, Mask, Src0)
  // Attempt to turn the masked gather in I into a MVE intrinsic
  // Potentially optimising the addressing modes as we do so.
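  // A sketch of the intended transformation (hypothetical IR; the exact
  // overload mangling is produced by CreateIntrinsic): with an all-ones mask,
  //   %g = call <4 x i32> @llvm.masked.gather.v4i32(<4 x i32*> %ptrs, ...)
  // can become
  //   %g = call <4 x i32> @llvm.arm.mve.vldr.gather.base(<4 x i32*> %ptrs, i32 0)
  // with the predicated and base-plus-offsets forms used in the other cases.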
  auto *Ty = cast<FixedVectorType>(I->getType());
  Value *Ptr = I->getArgOperand(0);
  Align Alignment = cast<ConstantInt>(I->getArgOperand(1))->getAlignValue();
  Value *Mask = I->getArgOperand(2);
  Value *PassThru = I->getArgOperand(3);

  if (!isLegalTypeAndAlignment(Ty->getNumElements(), Ty->getScalarSizeInBits(),
                               Alignment))
    return nullptr;
  lookThroughBitcast(Ptr);
  assert(Ptr->getType()->isVectorTy() && "Unexpected pointer type");

  IRBuilder<> Builder(I->getContext());
  Builder.SetInsertPoint(I);
  Builder.SetCurrentDebugLocation(I->getDebugLoc());

  Instruction *Root = I;
  Value *Load = tryCreateMaskedGatherOffset(I, Ptr, Root, Builder);
  if (!Load)
    Load = tryCreateMaskedGatherBase(I, Ptr, Builder);
  if (!Load)
    return nullptr;

  if (!isa<UndefValue>(PassThru) && !match(PassThru, m_Zero())) {
    LLVM_DEBUG(dbgs() << "masked gathers: found non-trivial passthru - "
                      << "creating select\n");
    Load = Builder.CreateSelect(Mask, Load, PassThru);
  }

  Root->replaceAllUsesWith(Load);
  Root->eraseFromParent();
  if (Root != I)
    // If this was an extending gather, we need to get rid of the sext/zext
    // as well as of the gather itself
    I->eraseFromParent();

  LLVM_DEBUG(dbgs() << "masked gathers: successfully built masked gather\n");
  return Load;
}

Value *MVEGatherScatterLowering::tryCreateMaskedGatherBase(
    IntrinsicInst *I, Value *Ptr, IRBuilder<> &Builder, int64_t Increment) {
  using namespace PatternMatch;
  auto *Ty = cast<FixedVectorType>(I->getType());
  LLVM_DEBUG(dbgs() << "masked gathers: loading from vector of pointers\n");
  if (Ty->getNumElements() != 4 || Ty->getScalarSizeInBits() != 32)
    // Can't build an intrinsic for this
    return nullptr;
  Value *Mask = I->getArgOperand(2);
  if (match(Mask, m_One()))
    return Builder.CreateIntrinsic(Intrinsic::arm_mve_vldr_gather_base,
                                   {Ty, Ptr->getType()},
                                   {Ptr, Builder.getInt32(Increment)});
  else
    return Builder.CreateIntrinsic(
        Intrinsic::arm_mve_vldr_gather_base_predicated,
        {Ty, Ptr->getType(), Mask->getType()},
        {Ptr, Builder.getInt32(Increment), Mask});
}

Value *MVEGatherScatterLowering::tryCreateMaskedGatherBaseWB(
    IntrinsicInst *I, Value *Ptr, IRBuilder<> &Builder, int64_t Increment) {
  using namespace PatternMatch;
  auto *Ty = cast<FixedVectorType>(I->getType());
  LLVM_DEBUG(
      dbgs()
      << "masked gathers: loading from vector of pointers with writeback\n");
  if (Ty->getNumElements() != 4 || Ty->getScalarSizeInBits() != 32)
    // Can't build an intrinsic for this
    return nullptr;
  Value *Mask = I->getArgOperand(2);
  if (match(Mask, m_One()))
    return Builder.CreateIntrinsic(Intrinsic::arm_mve_vldr_gather_base_wb,
                                   {Ty, Ptr->getType()},
                                   {Ptr, Builder.getInt32(Increment)});
  else
    return Builder.CreateIntrinsic(
        Intrinsic::arm_mve_vldr_gather_base_wb_predicated,
        {Ty, Ptr->getType(), Mask->getType()},
        {Ptr, Builder.getInt32(Increment), Mask});
}

Value *MVEGatherScatterLowering::tryCreateMaskedGatherOffset(
    IntrinsicInst *I, Value *Ptr, Instruction *&Root, IRBuilder<> &Builder) {
  using namespace PatternMatch;

  Type *OriginalTy = I->getType();
  Type *ResultTy = OriginalTy;

  unsigned Unsigned = 1;
  // The size of the gather was already checked in isLegalTypeAndAlignment;
  // if it was not a full vector width an appropriate extend should follow.
  auto *Extend = Root;
  if (OriginalTy->getPrimitiveSizeInBits() < 128) {
    // Only transform gathers with exactly one use
    if (!I->hasOneUse())
      return nullptr;

    // The correct root to replace is not the CallInst itself, but the
    // instruction which extends it
    Extend = cast<Instruction>(*I->users().begin());
    if (isa<SExtInst>(Extend)) {
      Unsigned = 0;
    } else if (!isa<ZExtInst>(Extend)) {
      LLVM_DEBUG(dbgs() << "masked gathers: extend needed but not provided."
" << "Expanding\n"); return nullptr; } LLVM_DEBUG(dbgs() << "masked gathers: found an extending gather\n"); ResultTy = Extend->getType(); // The final size of the gather must be a full vector width if (ResultTy->getPrimitiveSizeInBits() != 128) { LLVM_DEBUG(dbgs() << "masked gathers: extending from the wrong type. " << "Expanding\n"); return nullptr; } } GetElementPtrInst *GEP = dyn_cast(Ptr); Value *Offsets; Value *BasePtr = checkGEP(Offsets, ResultTy, GEP, Builder); if (!BasePtr) return nullptr; // Check whether the offset is a constant increment that could be merged into // a QI gather Value *Load = tryCreateIncrementingGatScat(I, BasePtr, Offsets, GEP, Builder); if (Load) return Load; int Scale = computeScale( BasePtr->getType()->getPointerElementType()->getPrimitiveSizeInBits(), OriginalTy->getScalarSizeInBits()); if (Scale == -1) return nullptr; Root = Extend; Value *Mask = I->getArgOperand(2); if (!match(Mask, m_One())) return Builder.CreateIntrinsic( Intrinsic::arm_mve_vldr_gather_offset_predicated, {ResultTy, BasePtr->getType(), Offsets->getType(), Mask->getType()}, {BasePtr, Offsets, Builder.getInt32(OriginalTy->getScalarSizeInBits()), Builder.getInt32(Scale), Builder.getInt32(Unsigned), Mask}); else return Builder.CreateIntrinsic( Intrinsic::arm_mve_vldr_gather_offset, {ResultTy, BasePtr->getType(), Offsets->getType()}, {BasePtr, Offsets, Builder.getInt32(OriginalTy->getScalarSizeInBits()), Builder.getInt32(Scale), Builder.getInt32(Unsigned)}); } Value *MVEGatherScatterLowering::lowerScatter(IntrinsicInst *I) { using namespace PatternMatch; LLVM_DEBUG(dbgs() << "masked scatters: checking transform preconditions\n"); // @llvm.masked.scatter.*(data, ptrs, alignment, mask) // Attempt to turn the masked scatter in I into a MVE intrinsic // Potentially optimising the addressing modes as we do so. 
  Value *Input = I->getArgOperand(0);
  Value *Ptr = I->getArgOperand(1);
  Align Alignment = cast<ConstantInt>(I->getArgOperand(2))->getAlignValue();
  auto *Ty = cast<FixedVectorType>(Input->getType());

  if (!isLegalTypeAndAlignment(Ty->getNumElements(), Ty->getScalarSizeInBits(),
                               Alignment))
    return nullptr;

  lookThroughBitcast(Ptr);
  assert(Ptr->getType()->isVectorTy() && "Unexpected pointer type");

  IRBuilder<> Builder(I->getContext());
  Builder.SetInsertPoint(I);
  Builder.SetCurrentDebugLocation(I->getDebugLoc());

  Value *Store = tryCreateMaskedScatterOffset(I, Ptr, Builder);
  if (!Store)
    Store = tryCreateMaskedScatterBase(I, Ptr, Builder);
  if (!Store)
    return nullptr;

  LLVM_DEBUG(dbgs() << "masked scatters: successfully built masked scatter\n");
  I->eraseFromParent();
  return Store;
}

Value *MVEGatherScatterLowering::tryCreateMaskedScatterBase(
    IntrinsicInst *I, Value *Ptr, IRBuilder<> &Builder, int64_t Increment) {
  using namespace PatternMatch;
  Value *Input = I->getArgOperand(0);
  auto *Ty = cast<FixedVectorType>(Input->getType());
  // Only QR variants allow truncating
  if (!(Ty->getNumElements() == 4 && Ty->getScalarSizeInBits() == 32)) {
    // Can't build an intrinsic for this
    return nullptr;
  }
  Value *Mask = I->getArgOperand(3);
  //  int_arm_mve_vstr_scatter_base(_predicated) addr, offset, data(, mask)
  LLVM_DEBUG(dbgs() << "masked scatters: storing to a vector of pointers\n");
  if (match(Mask, m_One()))
    return Builder.CreateIntrinsic(Intrinsic::arm_mve_vstr_scatter_base,
                                   {Ptr->getType(), Input->getType()},
                                   {Ptr, Builder.getInt32(Increment), Input});
  else
    return Builder.CreateIntrinsic(
        Intrinsic::arm_mve_vstr_scatter_base_predicated,
        {Ptr->getType(), Input->getType(), Mask->getType()},
        {Ptr, Builder.getInt32(Increment), Input, Mask});
}

Value *MVEGatherScatterLowering::tryCreateMaskedScatterBaseWB(
    IntrinsicInst *I, Value *Ptr, IRBuilder<> &Builder, int64_t Increment) {
  using namespace PatternMatch;
  Value *Input = I->getArgOperand(0);
  auto *Ty = cast<FixedVectorType>(Input->getType());
  LLVM_DEBUG(
      dbgs()
      << "masked scatters: storing to a vector of pointers with writeback\n");
  if (Ty->getNumElements() != 4 || Ty->getScalarSizeInBits() != 32)
    // Can't build an intrinsic for this
    return nullptr;
  Value *Mask = I->getArgOperand(3);
  if (match(Mask, m_One()))
    return Builder.CreateIntrinsic(Intrinsic::arm_mve_vstr_scatter_base_wb,
                                   {Ptr->getType(), Input->getType()},
                                   {Ptr, Builder.getInt32(Increment), Input});
  else
    return Builder.CreateIntrinsic(
        Intrinsic::arm_mve_vstr_scatter_base_wb_predicated,
        {Ptr->getType(), Input->getType(), Mask->getType()},
        {Ptr, Builder.getInt32(Increment), Input, Mask});
}

Value *MVEGatherScatterLowering::tryCreateMaskedScatterOffset(
    IntrinsicInst *I, Value *Ptr, IRBuilder<> &Builder) {
  using namespace PatternMatch;
  Value *Input = I->getArgOperand(0);
  Value *Mask = I->getArgOperand(3);
  Type *InputTy = Input->getType();
  Type *MemoryTy = InputTy;
  LLVM_DEBUG(dbgs() << "masked scatters: getelementpointer found. Storing"
                    << " to base + vector of offsets\n");
  // If the input has been truncated, try to integrate that trunc into the
  // scatter instruction (we don't care about alignment here)
  if (TruncInst *Trunc = dyn_cast<TruncInst>(Input)) {
    Value *PreTrunc = Trunc->getOperand(0);
    Type *PreTruncTy = PreTrunc->getType();
    if (PreTruncTy->getPrimitiveSizeInBits() == 128) {
      Input = PreTrunc;
      InputTy = PreTruncTy;
    }
  }
  if (InputTy->getPrimitiveSizeInBits() != 128) {
    LLVM_DEBUG(
        dbgs() << "masked scatters: cannot create scatters for non-standard"
               << " input types. Expanding.\n");
    return nullptr;
  }

  GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(Ptr);
  Value *Offsets;
  Value *BasePtr = checkGEP(Offsets, InputTy, GEP, Builder);
  if (!BasePtr)
    return nullptr;
  // Check whether the offset is a constant increment that could be merged
  // into a QI gather
  Value *Store =
      tryCreateIncrementingGatScat(I, BasePtr, Offsets, GEP, Builder);
  if (Store)
    return Store;

  int Scale = computeScale(
      BasePtr->getType()->getPointerElementType()->getPrimitiveSizeInBits(),
      MemoryTy->getScalarSizeInBits());
  if (Scale == -1)
    return nullptr;

  if (!match(Mask, m_One()))
    return Builder.CreateIntrinsic(
        Intrinsic::arm_mve_vstr_scatter_offset_predicated,
        {BasePtr->getType(), Offsets->getType(), Input->getType(),
         Mask->getType()},
        {BasePtr, Offsets, Input,
         Builder.getInt32(MemoryTy->getScalarSizeInBits()),
         Builder.getInt32(Scale), Mask});
  else
    return Builder.CreateIntrinsic(
        Intrinsic::arm_mve_vstr_scatter_offset,
        {BasePtr->getType(), Offsets->getType(), Input->getType()},
        {BasePtr, Offsets, Input,
         Builder.getInt32(MemoryTy->getScalarSizeInBits()),
         Builder.getInt32(Scale)});
}

Value *MVEGatherScatterLowering::tryCreateIncrementingGatScat(
    IntrinsicInst *I, Value *BasePtr, Value *Offsets, GetElementPtrInst *GEP,
    IRBuilder<> &Builder) {
  FixedVectorType *Ty;
  if (I->getIntrinsicID() == Intrinsic::masked_gather)
    Ty = cast<FixedVectorType>(I->getType());
  else
    Ty = cast<FixedVectorType>(I->getArgOperand(0)->getType());
  // Incrementing gathers only exist for v4i32
  if (Ty->getNumElements() != 4 || Ty->getScalarSizeInBits() != 32)
    return nullptr;
  Loop *L = LI->getLoopFor(I->getParent());
  if (L == nullptr)
    // Incrementing gathers are not beneficial outside of a loop
    return nullptr;
  LLVM_DEBUG(dbgs() << "masked gathers/scatters: trying to build incrementing "
                       "wb gather/scatter\n");

  // The gep was in charge of making sure the offsets are scaled correctly
  // - calculate that factor so it can be applied by hand
  DataLayout DT = I->getParent()->getParent()->getParent()->getDataLayout();
  int TypeScale =
      computeScale(DT.getTypeSizeInBits(GEP->getOperand(0)->getType()),
                   DT.getTypeSizeInBits(GEP->getType()) /
                       cast<FixedVectorType>(GEP->getType())->getNumElements());
  if (TypeScale == -1)
    return nullptr;

  if (GEP->hasOneUse()) {
    // Only in this case do we want to build a wb gather, because the wb will
    // change the phi which does affect other users of the gep (which will
    // still be using the phi in the old way)
    Value *Load = tryCreateIncrementingWBGatScat(I, BasePtr, Offsets,
                                                 TypeScale, Builder);
    if (Load != nullptr)
      return Load;
  }
  LLVM_DEBUG(dbgs() << "masked gathers/scatters: trying to build incrementing "
                       "non-wb gather/scatter\n");

  std::pair<Value *, int64_t> Add = getVarAndConst(Offsets, TypeScale);
  if (Add.first == nullptr)
    return nullptr;
  Value *OffsetsIncoming = Add.first;
  int64_t Immediate = Add.second;

  // Make sure the offsets are scaled correctly
  Instruction *ScaledOffsets = BinaryOperator::Create(
      Instruction::Shl, OffsetsIncoming,
      Builder.CreateVectorSplat(Ty->getNumElements(),
                                Builder.getInt32(TypeScale)),
      "ScaledIndex", I);
  // Add the base to the offsets
  OffsetsIncoming = BinaryOperator::Create(
      Instruction::Add, ScaledOffsets,
      Builder.CreateVectorSplat(
          Ty->getNumElements(),
          Builder.CreatePtrToInt(
              BasePtr,
              cast<VectorType>(ScaledOffsets->getType())->getElementType())),
      "StartIndex", I);

  if (I->getIntrinsicID() == Intrinsic::masked_gather)
    return cast<IntrinsicInst>(
        tryCreateMaskedGatherBase(I, OffsetsIncoming, Builder, Immediate));
  else
    return cast<IntrinsicInst>(
        tryCreateMaskedScatterBase(I, OffsetsIncoming, Builder, Immediate));
}

Value *MVEGatherScatterLowering::tryCreateIncrementingWBGatScat(
    IntrinsicInst *I,
    Value *BasePtr, Value *Offsets, unsigned TypeScale,
    IRBuilder<> &Builder) {
  // Check whether this gather's offset is incremented by a constant - if so,
  // and the load is of the right type, we can merge this into a QI gather
  Loop *L = LI->getLoopFor(I->getParent());
  // Offsets that are worth merging into this instruction will be incremented
  // by a constant, thus we're looking for an add of a phi and a constant
  PHINode *Phi = dyn_cast<PHINode>(Offsets);
  if (Phi == nullptr || Phi->getNumIncomingValues() != 2 ||
      Phi->getParent() != L->getHeader() || Phi->getNumUses() != 2)
    // No phi means no IV to write back to; if there is a phi, we expect it
    // to have exactly two incoming values; the only phis we are interested in
    // will be loop IV's and have exactly two uses, one in their increment and
    // one in the gather's gep
    return nullptr;

  unsigned IncrementIndex =
      Phi->getIncomingBlock(0) == L->getLoopLatch() ? 0 : 1;
  // Look through the phi to the phi increment
  Offsets = Phi->getIncomingValue(IncrementIndex);

  std::pair<Value *, int64_t> Add = getVarAndConst(Offsets, TypeScale);
  if (Add.first == nullptr)
    return nullptr;
  Value *OffsetsIncoming = Add.first;
  int64_t Immediate = Add.second;
  if (OffsetsIncoming != Phi)
    // Then the increment we are looking at is not an increment of the
    // induction variable, and we don't want to do a writeback
    return nullptr;

  Builder.SetInsertPoint(&Phi->getIncomingBlock(1 - IncrementIndex)->back());
  unsigned NumElems =
      cast<FixedVectorType>(OffsetsIncoming->getType())->getNumElements();

  // Make sure the offsets are scaled correctly
  Instruction *ScaledOffsets = BinaryOperator::Create(
      Instruction::Shl, Phi->getIncomingValue(1 - IncrementIndex),
      Builder.CreateVectorSplat(NumElems, Builder.getInt32(TypeScale)),
      "ScaledIndex", &Phi->getIncomingBlock(1 - IncrementIndex)->back());
  // Add the base to the offsets
  OffsetsIncoming = BinaryOperator::Create(
      Instruction::Add, ScaledOffsets,
      Builder.CreateVectorSplat(
          NumElems,
          Builder.CreatePtrToInt(
              BasePtr,
              cast<VectorType>(ScaledOffsets->getType())->getElementType())),
      "StartIndex", &Phi->getIncomingBlock(1 - IncrementIndex)->back());
  // The gather is pre-incrementing
  OffsetsIncoming = BinaryOperator::Create(
      Instruction::Sub, OffsetsIncoming,
      Builder.CreateVectorSplat(NumElems, Builder.getInt32(Immediate)),
      "PreIncrementStartIndex",
      &Phi->getIncomingBlock(1 - IncrementIndex)->back());
  Phi->setIncomingValue(1 - IncrementIndex, OffsetsIncoming);

  Builder.SetInsertPoint(I);

  Value *EndResult;
  Value *NewInduction;
  if (I->getIntrinsicID() == Intrinsic::masked_gather) {
    // Build the incrementing gather
    Value *Load = tryCreateMaskedGatherBaseWB(I, Phi, Builder, Immediate);
    // One value to be handed to whoever uses the gather, one is the loop
    // increment
    EndResult = Builder.CreateExtractValue(Load, 0, "Gather");
    NewInduction = Builder.CreateExtractValue(Load, 1, "GatherIncrement");
  } else {
    // Build the incrementing scatter
    NewInduction = tryCreateMaskedScatterBaseWB(I, Phi, Builder, Immediate);
    EndResult = NewInduction;
  }
  Instruction *AddInst = cast<Instruction>(Offsets);
  AddInst->replaceAllUsesWith(NewInduction);
  AddInst->eraseFromParent();
  Phi->setIncomingValue(IncrementIndex, NewInduction);

  return EndResult;
}

void MVEGatherScatterLowering::pushOutAdd(PHINode *&Phi,
                                          Value *OffsSecondOperand,
                                          unsigned StartIndex) {
  LLVM_DEBUG(
      dbgs() << "masked gathers/scatters: optimising add instruction\n");
  Instruction *InsertionPoint =
      &cast<Instruction>(Phi->getIncomingBlock(StartIndex)->back());
  // Initialize the phi with a vector that contains a sum of the constants
  Instruction *NewIndex = BinaryOperator::Create(
      Instruction::Add, Phi->getIncomingValue(StartIndex), OffsSecondOperand,
      "PushedOutAdd", InsertionPoint);
  unsigned IncrementIndex = StartIndex == 0 ? 1 : 0;

  // Order such that start index comes first (this reduces mov's)
  Phi->addIncoming(NewIndex, Phi->getIncomingBlock(StartIndex));
  Phi->addIncoming(Phi->getIncomingValue(IncrementIndex),
                   Phi->getIncomingBlock(IncrementIndex));
  Phi->removeIncomingValue(IncrementIndex);
  Phi->removeIncomingValue(StartIndex);
}

void MVEGatherScatterLowering::pushOutMul(PHINode *&Phi,
                                          Value *IncrementPerRound,
                                          Value *OffsSecondOperand,
                                          unsigned LoopIncrement,
                                          IRBuilder<> &Builder) {
  LLVM_DEBUG(
      dbgs() << "masked gathers/scatters: optimising mul instruction\n");

  // Create a new scalar add outside of the loop and transform it to a splat
  // by which the loop variable can be incremented
  Instruction *InsertionPoint = &cast<Instruction>(
      Phi->getIncomingBlock(LoopIncrement == 1 ? 0 : 1)->back());

  // Create a new index
  Value *StartIndex = BinaryOperator::Create(
      Instruction::Mul, Phi->getIncomingValue(LoopIncrement == 1 ? 0 : 1),
      OffsSecondOperand, "PushedOutMul", InsertionPoint);

  Instruction *Product =
      BinaryOperator::Create(Instruction::Mul, IncrementPerRound,
                             OffsSecondOperand, "Product", InsertionPoint);
  // Increment NewIndex by Product instead of the multiplication
  Instruction *NewIncrement = BinaryOperator::Create(
      Instruction::Add, Phi, Product, "IncrementPushedOutMul",
      cast<Instruction>(Phi->getIncomingBlock(LoopIncrement)->back())
          .getPrevNode());

  Phi->addIncoming(StartIndex,
                   Phi->getIncomingBlock(LoopIncrement == 1 ? 0 : 1));
  Phi->addIncoming(NewIncrement, Phi->getIncomingBlock(LoopIncrement));
  Phi->removeIncomingValue((unsigned)0);
  Phi->removeIncomingValue((unsigned)0);
}

// Check whether all usages of this instruction are as offsets of
// gathers/scatters or simple arithmetic only used by gathers/scatters
static bool hasAllGatScatUsers(Instruction *I) {
  if (I->hasNUses(0)) {
    return false;
  }
  bool Gatscat = true;
  for (User *U : I->users()) {
    if (!isa<Instruction>(U))
      return false;
    if (isa<GetElementPtrInst>(U) ||
        isGatherScatter(dyn_cast<IntrinsicInst>(U))) {
      return Gatscat;
    } else {
      unsigned OpCode = cast<Instruction>(U)->getOpcode();
      if ((OpCode == Instruction::Add || OpCode == Instruction::Mul) &&
          hasAllGatScatUsers(cast<Instruction>(U))) {
        continue;
      }
      return false;
    }
  }
  return Gatscat;
}

bool MVEGatherScatterLowering::optimiseOffsets(Value *Offsets, BasicBlock *BB,
                                               LoopInfo *LI) {
  LLVM_DEBUG(dbgs() << "masked gathers/scatters: trying to optimize\n");
  // Optimise the addresses of gathers/scatters by moving invariant
  // calculations out of the loop
  if (!isa<Instruction>(Offsets))
    return false;
  Instruction *Offs = cast<Instruction>(Offsets);
  if (Offs->getOpcode() != Instruction::Add &&
      Offs->getOpcode() != Instruction::Mul)
    return false;
  Loop *L = LI->getLoopFor(BB);
  if (L == nullptr)
    return false;
  if (!Offs->hasOneUse()) {
    if (!hasAllGatScatUsers(Offs))
      return false;
  }

  // Find out which, if any, operand of the instruction
  // is a phi node
  PHINode *Phi;
  int OffsSecondOp;
  if (isa<PHINode>(Offs->getOperand(0))) {
    Phi = cast<PHINode>(Offs->getOperand(0));
    OffsSecondOp = 1;
  } else if (isa<PHINode>(Offs->getOperand(1))) {
    Phi = cast<PHINode>(Offs->getOperand(1));
    OffsSecondOp = 0;
  } else {
    bool Changed = false;
    if (isa<Instruction>(Offs->getOperand(0)) &&
        L->contains(cast<Instruction>(Offs->getOperand(0))))
      Changed |= optimiseOffsets(Offs->getOperand(0), BB, LI);
    if (isa<Instruction>(Offs->getOperand(1)) &&
        L->contains(cast<Instruction>(Offs->getOperand(1))))
      Changed |= optimiseOffsets(Offs->getOperand(1), BB, LI);
    if (!Changed) {
      return false;
    } else {
      if (isa<PHINode>(Offs->getOperand(0))) {
        Phi = cast<PHINode>(Offs->getOperand(0));
        OffsSecondOp = 1;
      } else if
          (isa<PHINode>(Offs->getOperand(1))) {
        Phi = cast<PHINode>(Offs->getOperand(1));
        OffsSecondOp = 0;
      } else {
        return false;
      }
    }
  }
  // A phi node we want to perform this function on should be from the
  // loop header, and shouldn't have more than 2 incoming values
  if (Phi->getParent() != L->getHeader() ||
      Phi->getNumIncomingValues() != 2)
    return false;

  // The phi must be an induction variable
  Instruction *Op;
  int IncrementingBlock = -1;

  for (int i = 0; i < 2; i++)
    if ((Op = dyn_cast<Instruction>(Phi->getIncomingValue(i))) != nullptr)
      if (Op->getOpcode() == Instruction::Add &&
          (Op->getOperand(0) == Phi || Op->getOperand(1) == Phi))
        IncrementingBlock = i;
  if (IncrementingBlock == -1)
    return false;

  Instruction *IncInstruction =
      cast<Instruction>(Phi->getIncomingValue(IncrementingBlock));

  // If the phi is not used by anything else, we can just adapt it when
  // replacing the instruction; if it is, we'll have to duplicate it
  PHINode *NewPhi;
  Value *IncrementPerRound = IncInstruction->getOperand(
      (IncInstruction->getOperand(0) == Phi) ? 1 : 0);

  // Get the value that is added to/multiplied with the phi
  Value *OffsSecondOperand = Offs->getOperand(OffsSecondOp);

  if (IncrementPerRound->getType() != OffsSecondOperand->getType())
    // Something has gone wrong, abort
    return false;

  // Only proceed if the increment per round is a constant or an instruction
  // which does not originate from within the loop
  if (!isa<Constant>(IncrementPerRound) &&
      !(isa<Instruction>(IncrementPerRound) &&
        !L->contains(cast<Instruction>(IncrementPerRound))))
    return false;

  if (Phi->getNumUses() == 2) {
    // No other users -> reuse existing phi (One user is the instruction
    // we're looking at, the other is the phi increment)
    if (IncInstruction->getNumUses() != 1) {
      // If the incrementing instruction does have more users than
      // our phi, we need to copy it
      IncInstruction = BinaryOperator::Create(
          Instruction::BinaryOps(IncInstruction->getOpcode()), Phi,
          IncrementPerRound, "LoopIncrement", IncInstruction);
      Phi->setIncomingValue(IncrementingBlock, IncInstruction);
    }
    NewPhi = Phi;
  } else {
    // There are other users -> create a new phi
    NewPhi = PHINode::Create(Phi->getType(), 0, "NewPhi", Phi);
    // Copy the incoming values of the old phi
    NewPhi->addIncoming(Phi->getIncomingValue(IncrementingBlock == 1 ? 0 : 1),
                        Phi->getIncomingBlock(IncrementingBlock == 1 ? 0 : 1));
    IncInstruction = BinaryOperator::Create(
        Instruction::BinaryOps(IncInstruction->getOpcode()), NewPhi,
        IncrementPerRound, "LoopIncrement", IncInstruction);
    NewPhi->addIncoming(IncInstruction,
                        Phi->getIncomingBlock(IncrementingBlock));
    IncrementingBlock = 1;
  }

  IRBuilder<> Builder(BB->getContext());
  Builder.SetInsertPoint(Phi);
  Builder.SetCurrentDebugLocation(Offs->getDebugLoc());

  switch (Offs->getOpcode()) {
  case Instruction::Add:
    pushOutAdd(NewPhi, OffsSecondOperand, IncrementingBlock == 1 ?
                                                                 0 : 1);
    break;
  case Instruction::Mul:
    pushOutMul(NewPhi, IncrementPerRound, OffsSecondOperand, IncrementingBlock,
               Builder);
    break;
  default:
    return false;
  }
  LLVM_DEBUG(
      dbgs() << "masked gathers/scatters: simplified loop variable add/mul\n");

  // The instruction has now been "absorbed" into the phi value
  Offs->replaceAllUsesWith(NewPhi);
  if (Offs->hasNUses(0))
    Offs->eraseFromParent();
  // Clean up the old increment in case it's unused because we built a new
  // one
  if (IncInstruction->hasNUses(0))
    IncInstruction->eraseFromParent();

  return true;
}

bool MVEGatherScatterLowering::runOnFunction(Function &F) {
  if (!EnableMaskedGatherScatters)
    return false;
  auto &TPC = getAnalysis<TargetPassConfig>();
  auto &TM = TPC.getTM<TargetMachine>();
  auto *ST = &TM.getSubtarget<ARMSubtarget>(F);
  if (!ST->hasMVEIntegerOps())
    return false;
  LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
  SmallVector<IntrinsicInst *, 4> Gathers;
  SmallVector<IntrinsicInst *, 4> Scatters;
  bool Changed = false;

  for (BasicBlock &BB : F) {
    for (Instruction &I : BB) {
      IntrinsicInst *II = dyn_cast<IntrinsicInst>(&I);
      if (II && II->getIntrinsicID() == Intrinsic::masked_gather) {
        Gathers.push_back(II);
        if (isa<GetElementPtrInst>(II->getArgOperand(0)))
          Changed |= optimiseOffsets(
              cast<Instruction>(II->getArgOperand(0))->getOperand(1),
              II->getParent(), LI);
      } else if (II && II->getIntrinsicID() == Intrinsic::masked_scatter) {
        Scatters.push_back(II);
        if (isa<GetElementPtrInst>(II->getArgOperand(1)))
          Changed |= optimiseOffsets(
              cast<Instruction>(II->getArgOperand(1))->getOperand(1),
              II->getParent(), LI);
      }
    }
  }

  for (unsigned i = 0; i < Gathers.size(); i++) {
    IntrinsicInst *I = Gathers[i];
    Value *L = lowerGather(I);
    if (L == nullptr)
      continue;

    // Get rid of any now dead instructions
    SimplifyInstructionsInBlock(cast<Instruction>(L)->getParent());
    Changed = true;
  }

  for (unsigned i = 0; i < Scatters.size(); i++) {
    IntrinsicInst *I = Scatters[i];
    Value *S = lowerScatter(I);
    if (S == nullptr)
      continue;

    // Get rid of any now dead instructions
    SimplifyInstructionsInBlock(cast<Instruction>(S)->getParent());
    Changed = true;
  }
  return Changed;
}
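// Note: the pass is off by default (see EnableMaskedGatherScatters above); a
// sketch of an invocation that exercises it, assuming an MVE-capable target:
//   llc -mtriple=thumbv8.1m.main -mattr=+mve.fp -enable-arm-maskedgatscat ...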