//===----- SVEIntrinsicOpts - SVE ACLE Intrinsics Opts --------------------===//
//
//                     The LLVM Compiler Infrastructure
//
// This file is distributed under the University of Illinois Open Source
// License. See LICENSE.TXT for details.
//
//===----------------------------------------------------------------------===//
//
// Performs general IR level optimizations on SVE intrinsics.
//
// This pass performs the following optimizations:
//
// - removes unnecessary ptrue intrinsics (llvm.aarch64.sve.ptrue), e.g.:
//     %1 = @llvm.aarch64.sve.ptrue.nxv4i1(i32 31)
//     %2 = @llvm.aarch64.sve.ptrue.nxv8i1(i32 31)
//     ; (%1 can be replaced with a reinterpret of %2)
//
// - optimizes ptest intrinsics where the operands are being needlessly
//   converted to and from svbool_t.
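//
// - removes multiplications (llvm.aarch64.sve.mul, llvm.aarch64.sve.fmul) by a
//   unit splat (1 or 1.0) created via llvm.aarch64.sve.dup or
//   llvm.aarch64.sve.dup.x, forwarding the other multiplicand instead (see
//   optimizeVectorMul below).
//
// - rewrites llvm.aarch64.sve.tbl calls whose index operand is a constant,
//   in-range splat as an extractelement followed by a vector splat (see
//   optimizeTBL below).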
//
//===----------------------------------------------------------------------===//

#include "AArch64.h"
#include "Utils/AArch64BaseInfo.h"
#include "llvm/ADT/PostOrderIterator.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/IntrinsicsAArch64.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/InitializePasses.h"
#include "llvm/Support/Debug.h"

using namespace llvm;
using namespace llvm::PatternMatch;

#define DEBUG_TYPE "aarch64-sve-intrinsic-opts"

namespace llvm {
void initializeSVEIntrinsicOptsPass(PassRegistry &);
}

namespace {
struct SVEIntrinsicOpts : public ModulePass {
  static char ID; // Pass identification, replacement for typeid
  SVEIntrinsicOpts() : ModulePass(ID) {
    initializeSVEIntrinsicOptsPass(*PassRegistry::getPassRegistry());
  }

  bool runOnModule(Module &M) override;
  void getAnalysisUsage(AnalysisUsage &AU) const override;

private:
  bool coalescePTrueIntrinsicCalls(BasicBlock &BB,
                                   SmallSetVector<IntrinsicInst *, 4> &PTrues);
  bool optimizePTrueIntrinsicCalls(SmallSetVector<Function *, 4> &Functions);

  /// Operates at the instruction-scope. I.e., optimizations are applied local
  /// to individual instructions.
  static bool optimizeIntrinsic(Instruction *I);
  bool optimizeIntrinsicCalls(SmallSetVector<Function *, 4> &Functions);

  /// Operates at the function-scope. I.e., optimizations are applied local to
  /// the functions themselves.
  bool optimizeFunctions(SmallSetVector<Function *, 4> &Functions);

  static bool optimizePTest(IntrinsicInst *I);
  static bool optimizeVectorMul(IntrinsicInst *I);
  static bool optimizeTBL(IntrinsicInst *I);
};
} // end anonymous namespace

void SVEIntrinsicOpts::getAnalysisUsage(AnalysisUsage &AU) const {
  AU.addRequired<DominatorTreeWrapperPass>();
  AU.setPreservesCFG();
}

char SVEIntrinsicOpts::ID = 0;
static const char *name = "SVE intrinsics optimizations";
INITIALIZE_PASS_BEGIN(SVEIntrinsicOpts, DEBUG_TYPE, name, false, false)
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass);
INITIALIZE_PASS_END(SVEIntrinsicOpts, DEBUG_TYPE, name, false, false)

ModulePass *llvm::createSVEIntrinsicOptsPass() {
  return new SVEIntrinsicOpts();
}

/// Checks if a ptrue intrinsic call is promoted. The act of promoting a
/// ptrue will introduce zeroing. For example:
///
///     %1 = <vscale x 4 x i1> call @llvm.aarch64.sve.ptrue.nxv4i1(i32 31)
///     %2 = <vscale x 16 x i1> call @llvm.aarch64.sve.convert.to.svbool.nxv4i1(<vscale x 4 x i1> %1)
///     %3 = <vscale x 8 x i1> call @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> %2)
///
/// %1 is promoted, because it is converted:
///
///     <vscale x 4 x i1> => <vscale x 16 x i1> => <vscale x 8 x i1>
///
/// via a sequence of the SVE reinterpret intrinsics convert.{to,from}.svbool.
static bool isPTruePromoted(IntrinsicInst *PTrue) {
  // Find all users of this intrinsic that are calls to convert-to-svbool
  // reinterpret intrinsics.
  SmallVector<IntrinsicInst *, 4> ConvertToUses;
  for (User *User : PTrue->users()) {
    if (match(User, m_Intrinsic<Intrinsic::aarch64_sve_convert_to_svbool>())) {
      ConvertToUses.push_back(cast<IntrinsicInst>(User));
    }
  }

  // If no such calls were found, this ptrue is not promoted.
  if (ConvertToUses.empty())
    return false;

  // Otherwise, try to find users of the convert-to-svbool intrinsics that are
  // calls to the convert-from-svbool intrinsic, and would result in some lanes
  // being zeroed.
  const auto *PTrueVTy = cast<ScalableVectorType>(PTrue->getType());
  for (IntrinsicInst *ConvertToUse : ConvertToUses) {
    for (User *User : ConvertToUse->users()) {
      auto *IntrUser = dyn_cast<IntrinsicInst>(User);
      if (IntrUser && IntrUser->getIntrinsicID() ==
                          Intrinsic::aarch64_sve_convert_from_svbool) {
        const auto *IntrUserVTy = cast<ScalableVectorType>(IntrUser->getType());

        // Would some lanes become zeroed by the conversion?
        if (IntrUserVTy->getElementCount().getKnownMinValue() >
            PTrueVTy->getElementCount().getKnownMinValue())
          // This is a promoted ptrue.
          return true;
      }
    }
  }

  // If no matching calls were found, this is not a promoted ptrue.
  return false;
}

/// Attempts to coalesce ptrues in a basic block.
bool SVEIntrinsicOpts::coalescePTrueIntrinsicCalls(
    BasicBlock &BB, SmallSetVector<IntrinsicInst *, 4> &PTrues) {
  if (PTrues.size() <= 1)
    return false;

  // Find the ptrue with the most lanes.
  auto *MostEncompassingPTrue = *std::max_element(
      PTrues.begin(), PTrues.end(), [](auto *PTrue1, auto *PTrue2) {
        auto *PTrue1VTy = cast<ScalableVectorType>(PTrue1->getType());
        auto *PTrue2VTy = cast<ScalableVectorType>(PTrue2->getType());
        return PTrue1VTy->getElementCount().getKnownMinValue() <
               PTrue2VTy->getElementCount().getKnownMinValue();
      });

  // Remove the most encompassing ptrue, as well as any promoted ptrues,
  // leaving behind only the ptrues to be coalesced.
  PTrues.remove(MostEncompassingPTrue);
  PTrues.remove_if([](auto *PTrue) { return isPTruePromoted(PTrue); });

  // Hoist MostEncompassingPTrue to the start of the basic block. This is
  // always safe to do, since a ptrue intrinsic call takes no instruction
  // operands and therefore has no dependencies within the block.
  MostEncompassingPTrue->moveBefore(BB, BB.getFirstInsertionPt());

  LLVMContext &Ctx = BB.getContext();
  IRBuilder<> Builder(Ctx);
  Builder.SetInsertPoint(&BB, ++MostEncompassingPTrue->getIterator());

  auto *MostEncompassingPTrueVTy =
      cast<VectorType>(MostEncompassingPTrue->getType());
  auto *ConvertToSVBool = Builder.CreateIntrinsic(
      Intrinsic::aarch64_sve_convert_to_svbool, {MostEncompassingPTrueVTy},
      {MostEncompassingPTrue});

  bool ConvertFromCreated = false;
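  // For each remaining ptrue, reinterpret the hoisted ptrue down to the
  // narrower predicate type via convert.{to,from}.svbool and redirect all of
  // its uses to the reinterpreted value (or to the hoisted ptrue directly
  // when the types already match).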
  for (auto *PTrue : PTrues) {
    auto *PTrueVTy = cast<VectorType>(PTrue->getType());

    // Only create the converts if the types are not already the same,
    // otherwise just use the most encompassing ptrue.
    if (MostEncompassingPTrueVTy != PTrueVTy) {
      ConvertFromCreated = true;

      Builder.SetInsertPoint(&BB, ++ConvertToSVBool->getIterator());
      auto *ConvertFromSVBool =
          Builder.CreateIntrinsic(Intrinsic::aarch64_sve_convert_from_svbool,
                                  {PTrueVTy}, {ConvertToSVBool});
      PTrue->replaceAllUsesWith(ConvertFromSVBool);
    } else
      PTrue->replaceAllUsesWith(MostEncompassingPTrue);

    PTrue->eraseFromParent();
  }

  // We never used ConvertToSVBool, so remove it.
  if (!ConvertFromCreated)
    ConvertToSVBool->eraseFromParent();

  return true;
}

/// The goal of this function is to remove redundant calls to the SVE ptrue
/// intrinsic in each basic block within the given functions.
///
/// SVE ptrues have two representations in LLVM IR:
/// - a logical representation -- an arbitrary-width scalable vector of i1s,
///   i.e. <vscale x N x i1>.
/// - a physical representation (svbool, <vscale x 16 x i1>) -- a 16-element
///   scalable vector of i1s, i.e. <vscale x 16 x i1>.
///
/// The SVE ptrue intrinsic is used to create a logical representation of an SVE
/// predicate. Suppose that we have two SVE ptrue intrinsic calls: P1 and P2. If
/// P1 creates a logical SVE predicate that is at least as wide as the logical
/// SVE predicate created by P2, then all of the bits that are true in the
/// physical representation of P2 are necessarily also true in the physical
/// representation of P1. P1 'encompasses' P2; therefore, the intrinsic call to
/// P2 is redundant and can be replaced by an SVE reinterpret of P1 via
/// convert.{to,from}.svbool.
///
/// Currently, this pass only coalesces calls to SVE ptrue intrinsics
/// if they match the following conditions:
///
/// - the call to the intrinsic uses either the SV_ALL or SV_POW2 patterns.
///   SV_ALL indicates that all bits of the predicate vector are to be set to
///   true. SV_POW2 indicates that all bits of the predicate vector up to the
///   largest power-of-two are to be set to true.
/// - the result of the call to the intrinsic is not promoted to a wider
///   predicate. In this case, keeping the extra ptrue leads to better codegen
///   -- coalescing here would create an irreducible chain of SVE reinterprets
///   via convert.{to,from}.svbool.
///
/// EXAMPLE:
///
///     %1 = <vscale x 8 x i1> ptrue(i32 SV_ALL)
///     ; Logical:  <1, 1, 1, 1, 1, 1, 1, 1>
///     ; Physical: <1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0>
///     ...
///
///     %2 = <vscale x 4 x i1> ptrue(i32 SV_ALL)
///     ; Logical:  <1, 1, 1, 1>
///     ; Physical: <1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0>
///     ...
///
/// Here, %2 can be replaced by an SVE reinterpret of %1, giving, for instance:
///
///     %1 = <vscale x 8 x i1> ptrue(i32 SV_ALL)
///     %2 = <vscale x 16 x i1> convert.to.svbool(<vscale x 8 x i1> %1)
///     %3 = <vscale x 4 x i1> convert.from.svbool(<vscale x 16 x i1> %2)
///
bool SVEIntrinsicOpts::optimizePTrueIntrinsicCalls(
    SmallSetVector<Function *, 4> &Functions) {
  bool Changed = false;

  for (auto *F : Functions) {
    for (auto &BB : *F) {
      SmallSetVector<IntrinsicInst *, 4> SVAllPTrues;
      SmallSetVector<IntrinsicInst *, 4> SVPow2PTrues;

      // For each basic block, collect the used ptrues and try to coalesce them.
      for (Instruction &I : BB) {
        if (I.use_empty())
          continue;

        auto *IntrI = dyn_cast<IntrinsicInst>(&I);
        if (!IntrI || IntrI->getIntrinsicID() != Intrinsic::aarch64_sve_ptrue)
          continue;

        const auto PTruePattern =
            cast<ConstantInt>(IntrI->getOperand(0))->getZExtValue();

        if (PTruePattern == AArch64SVEPredPattern::all)
          SVAllPTrues.insert(IntrI);
        if (PTruePattern == AArch64SVEPredPattern::pow2)
          SVPow2PTrues.insert(IntrI);
      }

      Changed |= coalescePTrueIntrinsicCalls(BB, SVAllPTrues);
      Changed |= coalescePTrueIntrinsicCalls(BB, SVPow2PTrues);
    }
  }

  return Changed;
}

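/// Simplifies a ptest whose operands are both produced by convert.to.svbool
/// calls on values of the same predicate type, so that the ptest operates on
/// the original (unconverted) operands instead. An illustrative sketch, using
/// the same pseudo-IR notation as above (%a and %b are placeholder names):
///
///     %1 = <vscale x 16 x i1> convert.to.svbool(<vscale x 4 x i1> %a)
///     %2 = <vscale x 16 x i1> convert.to.svbool(<vscale x 4 x i1> %b)
///     %3 = i1 ptest.any(<vscale x 16 x i1> %1, <vscale x 16 x i1> %2)
///     ; (%3 can be replaced with a ptest of %a and %b directly)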
bool SVEIntrinsicOpts::optimizePTest(IntrinsicInst *I) {
  IntrinsicInst *Op1 = dyn_cast<IntrinsicInst>(I->getArgOperand(0));
  IntrinsicInst *Op2 = dyn_cast<IntrinsicInst>(I->getArgOperand(1));

  if (Op1 && Op2 &&
      Op1->getIntrinsicID() == Intrinsic::aarch64_sve_convert_to_svbool &&
      Op2->getIntrinsicID() == Intrinsic::aarch64_sve_convert_to_svbool &&
      Op1->getArgOperand(0)->getType() == Op2->getArgOperand(0)->getType()) {

    Value *Ops[] = {Op1->getArgOperand(0), Op2->getArgOperand(0)};
    Type *Tys[] = {Op1->getArgOperand(0)->getType()};
    Module *M = I->getParent()->getParent()->getParent();

    auto Fn = Intrinsic::getDeclaration(M, I->getIntrinsicID(), Tys);
    auto CI = CallInst::Create(Fn, Ops, I->getName(), I);

    I->replaceAllUsesWith(CI);
    I->eraseFromParent();
    if (Op1->use_empty())
      Op1->eraseFromParent();
    if (Op1 != Op2 && Op2->use_empty())
      Op2->eraseFromParent();

    return true;
  }

  return false;
}

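/// Removes [f]mul intrinsic calls where one multiplicand is a unit splat
/// (1 or 1.0) created by dup or dup_x, forwarding the other multiplicand
/// instead. An illustrative sketch in pseudo-IR (%pg and %n are placeholder
/// names):
///
///     %one = dup_x(1)
///     %res = mul(%pg, %n, %one)
///     ; (%res can be replaced with %n)
///
/// For the predicated dup form, the replacement is currently only done when
/// the dup and the mul use the same predicate.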
bool SVEIntrinsicOpts::optimizeVectorMul(IntrinsicInst *I) {
  assert((I->getIntrinsicID() == Intrinsic::aarch64_sve_mul ||
          I->getIntrinsicID() == Intrinsic::aarch64_sve_fmul) &&
         "Unexpected opcode");

  auto *OpPredicate = I->getOperand(0);
  auto *OpMultiplicand = I->getOperand(1);
  auto *OpMultiplier = I->getOperand(2);

  // Return true if a given instruction is an aarch64_sve_dup_x intrinsic call
  // with a unit splat value, false otherwise.
  auto IsUnitDupX = [](auto *I) {
    auto *IntrI = dyn_cast<IntrinsicInst>(I);
    if (!IntrI || IntrI->getIntrinsicID() != Intrinsic::aarch64_sve_dup_x)
      return false;

    auto *SplatValue = IntrI->getOperand(0);
    return match(SplatValue, m_FPOne()) || match(SplatValue, m_One());
  };

  // Return true if a given instruction is an aarch64_sve_dup intrinsic call
  // with a unit splat value, false otherwise.
  auto IsUnitDup = [](auto *I) {
    auto *IntrI = dyn_cast<IntrinsicInst>(I);
    if (!IntrI || IntrI->getIntrinsicID() != Intrinsic::aarch64_sve_dup)
      return false;

    auto *SplatValue = IntrI->getOperand(2);
    return match(SplatValue, m_FPOne()) || match(SplatValue, m_One());
  };

  bool Changed = false;

  // The OpMultiplier variable should always point to the dup (if any), so
  // swap if necessary.
  if (IsUnitDup(OpMultiplicand) || IsUnitDupX(OpMultiplicand))
    std::swap(OpMultiplier, OpMultiplicand);

  if (IsUnitDupX(OpMultiplier)) {
    // [f]mul pg (dupx 1) %n => %n
    I->replaceAllUsesWith(OpMultiplicand);
    I->eraseFromParent();
    Changed = true;
  } else if (IsUnitDup(OpMultiplier)) {
    // [f]mul pg (dup pg 1) %n => %n
    auto *DupInst = cast<IntrinsicInst>(OpMultiplier);
    auto *DupPg = DupInst->getOperand(1);
    // TODO: this is naive. The optimization is still valid if DupPg
    // 'encompasses' OpPredicate, not only if they're the same predicate.
    if (OpPredicate == DupPg) {
      I->replaceAllUsesWith(OpMultiplicand);
      I->eraseFromParent();
      Changed = true;
    }
  }

  // If an instruction was optimized out then it is possible that some dangling
  // instructions are left.
  if (Changed) {
    auto *OpPredicateInst = dyn_cast<Instruction>(OpPredicate);
    auto *OpMultiplierInst = dyn_cast<Instruction>(OpMultiplier);
    if (OpMultiplierInst && OpMultiplierInst->use_empty())
      OpMultiplierInst->eraseFromParent();
    if (OpPredicateInst && OpPredicateInst->use_empty())
      OpPredicateInst->eraseFromParent();
  }

  return Changed;
}

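/// Rewrites a tbl intrinsic whose index operand is a constant splat (dup_x)
/// with an index smaller than the minimum number of result elements as an
/// extractelement followed by a vector splat. An illustrative sketch in
/// pseudo-IR (%v is a placeholder name, C is the constant index):
///
///     tbl(%v, dup_x(C))  ==>  splat_vector(extractelement(%v, C))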
bool SVEIntrinsicOpts::optimizeTBL(IntrinsicInst *I) {
  assert(I->getIntrinsicID() == Intrinsic::aarch64_sve_tbl &&
         "Unexpected opcode");

  auto *OpVal = I->getOperand(0);
  auto *OpIndices = I->getOperand(1);
  VectorType *VTy = cast<VectorType>(I->getType());

  // Check whether OpIndices is an aarch64_sve_dup_x intrinsic call with a
  // constant splat value smaller than the minimum element count of the result.
  auto *DupXIntrI = dyn_cast<IntrinsicInst>(OpIndices);
  if (!DupXIntrI || DupXIntrI->getIntrinsicID() != Intrinsic::aarch64_sve_dup_x)
    return false;

  auto *SplatValue = dyn_cast<ConstantInt>(DupXIntrI->getOperand(0));
  if (!SplatValue ||
      SplatValue->getValue().uge(VTy->getElementCount().getKnownMinValue()))
    return false;

  // Convert sve_tbl(OpVal, sve_dup_x(SplatValue)) to
  // splat_vector(extractelement(OpVal, SplatValue)) for further optimization.
  LLVMContext &Ctx = I->getContext();
  IRBuilder<> Builder(Ctx);
  Builder.SetInsertPoint(I);
  auto *Extract = Builder.CreateExtractElement(OpVal, SplatValue);
  auto *VectorSplat =
      Builder.CreateVectorSplat(VTy->getElementCount(), Extract);

  I->replaceAllUsesWith(VectorSplat);
  I->eraseFromParent();
  if (DupXIntrI->use_empty())
    DupXIntrI->eraseFromParent();
  return true;
}

bool SVEIntrinsicOpts::optimizeIntrinsic(Instruction *I) {
  IntrinsicInst *IntrI = dyn_cast<IntrinsicInst>(I);
  if (!IntrI)
    return false;

  switch (IntrI->getIntrinsicID()) {
  case Intrinsic::aarch64_sve_fmul:
  case Intrinsic::aarch64_sve_mul:
    return optimizeVectorMul(IntrI);
  case Intrinsic::aarch64_sve_ptest_any:
  case Intrinsic::aarch64_sve_ptest_first:
  case Intrinsic::aarch64_sve_ptest_last:
    return optimizePTest(IntrI);
  case Intrinsic::aarch64_sve_tbl:
    return optimizeTBL(IntrI);
  default:
    return false;
  }

  return true;
}

bool SVEIntrinsicOpts::optimizeIntrinsicCalls(
    SmallSetVector<Function *, 4> &Functions) {
  bool Changed = false;
  for (auto *F : Functions) {
    DominatorTree *DT = &getAnalysis<DominatorTreeWrapperPass>(*F).getDomTree();

    // Traverse the DT with an rpo walk so we see defs before uses, allowing
    // simplification to be done incrementally.
    BasicBlock *Root = DT->getRoot();
    ReversePostOrderTraversal<BasicBlock *> RPOT(Root);
    for (auto *BB : RPOT)
      for (Instruction &I : make_early_inc_range(*BB))
        Changed |= optimizeIntrinsic(&I);
  }
  return Changed;
}

bool SVEIntrinsicOpts::optimizeFunctions(
    SmallSetVector<Function *, 4> &Functions) {
  bool Changed = false;

  Changed |= optimizePTrueIntrinsicCalls(Functions);
  Changed |= optimizeIntrinsicCalls(Functions);

  return Changed;
}

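/// Collects every function that uses one of the SVE intrinsics handled by
/// this pass, then runs the function-scope optimizations over that set.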
bool SVEIntrinsicOpts::runOnModule(Module &M) {
  bool Changed = false;
  SmallSetVector<Function *, 4> Functions;

  // Check for SVE intrinsic declarations first so that we only iterate over
  // relevant functions. Where an appropriate declaration is found, store the
  // function(s) where it is used so we can target these only.
  for (auto &F : M.getFunctionList()) {
    if (!F.isDeclaration())
      continue;

    switch (F.getIntrinsicID()) {
    case Intrinsic::aarch64_sve_ptest_any:
    case Intrinsic::aarch64_sve_ptest_first:
    case Intrinsic::aarch64_sve_ptest_last:
    case Intrinsic::aarch64_sve_ptrue:
    case Intrinsic::aarch64_sve_mul:
    case Intrinsic::aarch64_sve_fmul:
    case Intrinsic::aarch64_sve_tbl:
      for (User *U : F.users())
        Functions.insert(cast<Instruction>(U)->getFunction());
      break;
    default:
      break;
    }
  }

  if (!Functions.empty())
    Changed |= optimizeFunctions(Functions);

  return Changed;
}