1 //===----- SVEIntrinsicOpts - SVE ACLE Intrinsics Opts --------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Performs general IR level optimizations on SVE intrinsics.
10 //
11 // This pass performs the following optimizations:
12 //
13 // - removes unnecessary ptrue intrinsics (llvm.aarch64.sve.ptrue), e.g:
14 //     %1 = @llvm.aarch64.sve.ptrue.nxv4i1(i32 31)
15 //     %2 = @llvm.aarch64.sve.ptrue.nxv8i1(i32 31)
16 //     ; (%1 can be replaced with a reinterpret of %2)
17 //
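// - rewrites loads and stores of fixed-length predicates that are only
//   reinterpreted to/from svbool_t into direct scalable predicate loads and
//   stores, when the function's vscale_range attribute pins vscale to a
//   single value.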
20 //
21 //===----------------------------------------------------------------------===//
22 
23 #include "AArch64.h"
24 #include "Utils/AArch64BaseInfo.h"
25 #include "llvm/ADT/PostOrderIterator.h"
26 #include "llvm/ADT/SetVector.h"
27 #include "llvm/IR/Constants.h"
28 #include "llvm/IR/Dominators.h"
29 #include "llvm/IR/IRBuilder.h"
30 #include "llvm/IR/Instructions.h"
31 #include "llvm/IR/IntrinsicInst.h"
32 #include "llvm/IR/IntrinsicsAArch64.h"
33 #include "llvm/IR/LLVMContext.h"
34 #include "llvm/IR/PatternMatch.h"
35 #include "llvm/InitializePasses.h"
36 #include "llvm/Support/Debug.h"
37 #include <optional>
38 
39 using namespace llvm;
40 using namespace llvm::PatternMatch;
41 
42 #define DEBUG_TYPE "aarch64-sve-intrinsic-opts"
43 
44 namespace {
45 struct SVEIntrinsicOpts : public ModulePass {
46   static char ID; // Pass identification, replacement for typeid
47   SVEIntrinsicOpts() : ModulePass(ID) {
48     initializeSVEIntrinsicOptsPass(*PassRegistry::getPassRegistry());
49   }
50 
51   bool runOnModule(Module &M) override;
52   void getAnalysisUsage(AnalysisUsage &AU) const override;
53 
54 private:
55   bool coalescePTrueIntrinsicCalls(BasicBlock &BB,
56                                    SmallSetVector<IntrinsicInst *, 4> &PTrues);
57   bool optimizePTrueIntrinsicCalls(SmallSetVector<Function *, 4> &Functions);
58   bool optimizePredicateStore(Instruction *I);
59   bool optimizePredicateLoad(Instruction *I);
60 
61   bool optimizeInstructions(SmallSetVector<Function *, 4> &Functions);
62 
63   /// Operates at the function-scope. I.e., optimizations are applied local to
64   /// the functions themselves.
65   bool optimizeFunctions(SmallSetVector<Function *, 4> &Functions);
66 };
67 } // end anonymous namespace
68 
69 void SVEIntrinsicOpts::getAnalysisUsage(AnalysisUsage &AU) const {
70   AU.addRequired<DominatorTreeWrapperPass>();
71   AU.setPreservesCFG();
72 }
73 
// Definition of the pass identification object declared in SVEIntrinsicOpts.
char SVEIntrinsicOpts::ID = 0;
// Human-readable pass description shared by the registration macros below.
static const char *name = "SVE intrinsics optimizations";
// Register the pass under DEBUG_TYPE ("aarch64-sve-intrinsic-opts") and record
// its dependency on the dominator tree wrapper pass.
INITIALIZE_PASS_BEGIN(SVEIntrinsicOpts, DEBUG_TYPE, name, false, false)
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass);
INITIALIZE_PASS_END(SVEIntrinsicOpts, DEBUG_TYPE, name, false, false)
79 
80 ModulePass *llvm::createSVEIntrinsicOptsPass() {
81   return new SVEIntrinsicOpts();
82 }
83 
84 /// Checks if a ptrue intrinsic call is promoted. The act of promoting a
85 /// ptrue will introduce zeroing. For example:
86 ///
87 ///     %1 = <vscale x 4 x i1> call @llvm.aarch64.sve.ptrue.nxv4i1(i32 31)
88 ///     %2 = <vscale x 16 x i1> call @llvm.aarch64.sve.convert.to.svbool.nxv4i1(<vscale x 4 x i1> %1)
89 ///     %3 = <vscale x 8 x i1> call @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> %2)
90 ///
91 /// %1 is promoted, because it is converted:
92 ///
93 ///     <vscale x 4 x i1> => <vscale x 16 x i1> => <vscale x 8 x i1>
94 ///
95 /// via a sequence of the SVE reinterpret intrinsics convert.{to,from}.svbool.
96 static bool isPTruePromoted(IntrinsicInst *PTrue) {
97   // Find all users of this intrinsic that are calls to convert-to-svbool
98   // reinterpret intrinsics.
99   SmallVector<IntrinsicInst *, 4> ConvertToUses;
100   for (User *User : PTrue->users()) {
101     if (match(User, m_Intrinsic<Intrinsic::aarch64_sve_convert_to_svbool>())) {
102       ConvertToUses.push_back(cast<IntrinsicInst>(User));
103     }
104   }
105 
106   // If no such calls were found, this is ptrue is not promoted.
107   if (ConvertToUses.empty())
108     return false;
109 
110   // Otherwise, try to find users of the convert-to-svbool intrinsics that are
111   // calls to the convert-from-svbool intrinsic, and would result in some lanes
112   // being zeroed.
113   const auto *PTrueVTy = cast<ScalableVectorType>(PTrue->getType());
114   for (IntrinsicInst *ConvertToUse : ConvertToUses) {
115     for (User *User : ConvertToUse->users()) {
116       auto *IntrUser = dyn_cast<IntrinsicInst>(User);
117       if (IntrUser && IntrUser->getIntrinsicID() ==
118                           Intrinsic::aarch64_sve_convert_from_svbool) {
119         const auto *IntrUserVTy = cast<ScalableVectorType>(IntrUser->getType());
120 
121         // Would some lanes become zeroed by the conversion?
122         if (IntrUserVTy->getElementCount().getKnownMinValue() >
123             PTrueVTy->getElementCount().getKnownMinValue())
124           // This is a promoted ptrue.
125           return true;
126       }
127     }
128   }
129 
130   // If no matching calls were found, this is not a promoted ptrue.
131   return false;
132 }
133 
134 /// Attempts to coalesce ptrues in a basic block.
135 bool SVEIntrinsicOpts::coalescePTrueIntrinsicCalls(
136     BasicBlock &BB, SmallSetVector<IntrinsicInst *, 4> &PTrues) {
137   if (PTrues.size() <= 1)
138     return false;
139 
140   // Find the ptrue with the most lanes.
141   auto *MostEncompassingPTrue = *std::max_element(
142       PTrues.begin(), PTrues.end(), [](auto *PTrue1, auto *PTrue2) {
143         auto *PTrue1VTy = cast<ScalableVectorType>(PTrue1->getType());
144         auto *PTrue2VTy = cast<ScalableVectorType>(PTrue2->getType());
145         return PTrue1VTy->getElementCount().getKnownMinValue() <
146                PTrue2VTy->getElementCount().getKnownMinValue();
147       });
148 
149   // Remove the most encompassing ptrue, as well as any promoted ptrues, leaving
150   // behind only the ptrues to be coalesced.
151   PTrues.remove(MostEncompassingPTrue);
152   PTrues.remove_if(isPTruePromoted);
153 
154   // Hoist MostEncompassingPTrue to the start of the basic block. It is always
155   // safe to do this, since ptrue intrinsic calls are guaranteed to have no
156   // predecessors.
157   MostEncompassingPTrue->moveBefore(BB, BB.getFirstInsertionPt());
158 
159   LLVMContext &Ctx = BB.getContext();
160   IRBuilder<> Builder(Ctx);
161   Builder.SetInsertPoint(&BB, ++MostEncompassingPTrue->getIterator());
162 
163   auto *MostEncompassingPTrueVTy =
164       cast<VectorType>(MostEncompassingPTrue->getType());
165   auto *ConvertToSVBool = Builder.CreateIntrinsic(
166       Intrinsic::aarch64_sve_convert_to_svbool, {MostEncompassingPTrueVTy},
167       {MostEncompassingPTrue});
168 
169   bool ConvertFromCreated = false;
170   for (auto *PTrue : PTrues) {
171     auto *PTrueVTy = cast<VectorType>(PTrue->getType());
172 
173     // Only create the converts if the types are not already the same, otherwise
174     // just use the most encompassing ptrue.
175     if (MostEncompassingPTrueVTy != PTrueVTy) {
176       ConvertFromCreated = true;
177 
178       Builder.SetInsertPoint(&BB, ++ConvertToSVBool->getIterator());
179       auto *ConvertFromSVBool =
180           Builder.CreateIntrinsic(Intrinsic::aarch64_sve_convert_from_svbool,
181                                   {PTrueVTy}, {ConvertToSVBool});
182       PTrue->replaceAllUsesWith(ConvertFromSVBool);
183     } else
184       PTrue->replaceAllUsesWith(MostEncompassingPTrue);
185 
186     PTrue->eraseFromParent();
187   }
188 
189   // We never used the ConvertTo so remove it
190   if (!ConvertFromCreated)
191     ConvertToSVBool->eraseFromParent();
192 
193   return true;
194 }
195 
196 /// The goal of this function is to remove redundant calls to the SVE ptrue
197 /// intrinsic in each basic block within the given functions.
198 ///
199 /// SVE ptrues have two representations in LLVM IR:
200 /// - a logical representation -- an arbitrary-width scalable vector of i1s,
201 ///   i.e. <vscale x N x i1>.
202 /// - a physical representation (svbool, <vscale x 16 x i1>) -- a 16-element
203 ///   scalable vector of i1s, i.e. <vscale x 16 x i1>.
204 ///
205 /// The SVE ptrue intrinsic is used to create a logical representation of an SVE
206 /// predicate. Suppose that we have two SVE ptrue intrinsic calls: P1 and P2. If
207 /// P1 creates a logical SVE predicate that is at least as wide as the logical
208 /// SVE predicate created by P2, then all of the bits that are true in the
209 /// physical representation of P2 are necessarily also true in the physical
210 /// representation of P1. P1 'encompasses' P2, therefore, the intrinsic call to
211 /// P2 is redundant and can be replaced by an SVE reinterpret of P1 via
212 /// convert.{to,from}.svbool.
213 ///
214 /// Currently, this pass only coalesces calls to SVE ptrue intrinsics
215 /// if they match the following conditions:
216 ///
217 /// - the call to the intrinsic uses either the SV_ALL or SV_POW2 patterns.
218 ///   SV_ALL indicates that all bits of the predicate vector are to be set to
219 ///   true. SV_POW2 indicates that all bits of the predicate vector up to the
220 ///   largest power-of-two are to be set to true.
221 /// - the result of the call to the intrinsic is not promoted to a wider
222 ///   predicate. In this case, keeping the extra ptrue leads to better codegen
223 ///   -- coalescing here would create an irreducible chain of SVE reinterprets
224 ///   via convert.{to,from}.svbool.
225 ///
226 /// EXAMPLE:
227 ///
228 ///     %1 = <vscale x 8 x i1> ptrue(i32 SV_ALL)
229 ///     ; Logical:  <1, 1, 1, 1, 1, 1, 1, 1>
230 ///     ; Physical: <1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0>
231 ///     ...
232 ///
233 ///     %2 = <vscale x 4 x i1> ptrue(i32 SV_ALL)
234 ///     ; Logical:  <1, 1, 1, 1>
235 ///     ; Physical: <1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0>
236 ///     ...
237 ///
238 /// Here, %2 can be replaced by an SVE reinterpret of %1, giving, for instance:
239 ///
240 ///     %1 = <vscale x 8 x i1> ptrue(i32 i31)
241 ///     %2 = <vscale x 16 x i1> convert.to.svbool(<vscale x 8 x i1> %1)
242 ///     %3 = <vscale x 4 x i1> convert.from.svbool(<vscale x 16 x i1> %2)
243 ///
244 bool SVEIntrinsicOpts::optimizePTrueIntrinsicCalls(
245     SmallSetVector<Function *, 4> &Functions) {
246   bool Changed = false;
247 
248   for (auto *F : Functions) {
249     for (auto &BB : *F) {
250       SmallSetVector<IntrinsicInst *, 4> SVAllPTrues;
251       SmallSetVector<IntrinsicInst *, 4> SVPow2PTrues;
252 
253       // For each basic block, collect the used ptrues and try to coalesce them.
254       for (Instruction &I : BB) {
255         if (I.use_empty())
256           continue;
257 
258         auto *IntrI = dyn_cast<IntrinsicInst>(&I);
259         if (!IntrI || IntrI->getIntrinsicID() != Intrinsic::aarch64_sve_ptrue)
260           continue;
261 
262         const auto PTruePattern =
263             cast<ConstantInt>(IntrI->getOperand(0))->getZExtValue();
264 
265         if (PTruePattern == AArch64SVEPredPattern::all)
266           SVAllPTrues.insert(IntrI);
267         if (PTruePattern == AArch64SVEPredPattern::pow2)
268           SVPow2PTrues.insert(IntrI);
269       }
270 
271       Changed |= coalescePTrueIntrinsicCalls(BB, SVAllPTrues);
272       Changed |= coalescePTrueIntrinsicCalls(BB, SVPow2PTrues);
273     }
274   }
275 
276   return Changed;
277 }
278 
279 // This is done in SVEIntrinsicOpts rather than InstCombine so that we introduce
280 // scalable stores as late as possible
281 bool SVEIntrinsicOpts::optimizePredicateStore(Instruction *I) {
282   auto *F = I->getFunction();
283   auto Attr = F->getFnAttribute(Attribute::VScaleRange);
284   if (!Attr.isValid())
285     return false;
286 
287   unsigned MinVScale = Attr.getVScaleRangeMin();
288   std::optional<unsigned> MaxVScale = Attr.getVScaleRangeMax();
289   // The transform needs to know the exact runtime length of scalable vectors
290   if (!MaxVScale || MinVScale != MaxVScale)
291     return false;
292 
293   auto *PredType =
294       ScalableVectorType::get(Type::getInt1Ty(I->getContext()), 16);
295   auto *FixedPredType =
296       FixedVectorType::get(Type::getInt8Ty(I->getContext()), MinVScale * 2);
297 
298   // If we have a store..
299   auto *Store = dyn_cast<StoreInst>(I);
300   if (!Store || !Store->isSimple())
301     return false;
302 
303   // ..that is storing a predicate vector sized worth of bits..
304   if (Store->getOperand(0)->getType() != FixedPredType)
305     return false;
306 
307   // ..where the value stored comes from a vector extract..
308   auto *IntrI = dyn_cast<IntrinsicInst>(Store->getOperand(0));
309   if (!IntrI || IntrI->getIntrinsicID() != Intrinsic::vector_extract)
310     return false;
311 
312   // ..that is extracting from index 0..
313   if (!cast<ConstantInt>(IntrI->getOperand(1))->isZero())
314     return false;
315 
316   // ..where the value being extract from comes from a bitcast
317   auto *BitCast = dyn_cast<BitCastInst>(IntrI->getOperand(0));
318   if (!BitCast)
319     return false;
320 
321   // ..and the bitcast is casting from predicate type
322   if (BitCast->getOperand(0)->getType() != PredType)
323     return false;
324 
325   IRBuilder<> Builder(I->getContext());
326   Builder.SetInsertPoint(I);
327 
328   auto *PtrBitCast = Builder.CreateBitCast(
329       Store->getPointerOperand(),
330       PredType->getPointerTo(Store->getPointerAddressSpace()));
331   Builder.CreateStore(BitCast->getOperand(0), PtrBitCast);
332 
333   Store->eraseFromParent();
334   if (IntrI->getNumUses() == 0)
335     IntrI->eraseFromParent();
336   if (BitCast->getNumUses() == 0)
337     BitCast->eraseFromParent();
338 
339   return true;
340 }
341 
342 // This is done in SVEIntrinsicOpts rather than InstCombine so that we introduce
343 // scalable loads as late as possible
344 bool SVEIntrinsicOpts::optimizePredicateLoad(Instruction *I) {
345   auto *F = I->getFunction();
346   auto Attr = F->getFnAttribute(Attribute::VScaleRange);
347   if (!Attr.isValid())
348     return false;
349 
350   unsigned MinVScale = Attr.getVScaleRangeMin();
351   std::optional<unsigned> MaxVScale = Attr.getVScaleRangeMax();
352   // The transform needs to know the exact runtime length of scalable vectors
353   if (!MaxVScale || MinVScale != MaxVScale)
354     return false;
355 
356   auto *PredType =
357       ScalableVectorType::get(Type::getInt1Ty(I->getContext()), 16);
358   auto *FixedPredType =
359       FixedVectorType::get(Type::getInt8Ty(I->getContext()), MinVScale * 2);
360 
361   // If we have a bitcast..
362   auto *BitCast = dyn_cast<BitCastInst>(I);
363   if (!BitCast || BitCast->getType() != PredType)
364     return false;
365 
366   // ..whose operand is a vector_insert..
367   auto *IntrI = dyn_cast<IntrinsicInst>(BitCast->getOperand(0));
368   if (!IntrI || IntrI->getIntrinsicID() != Intrinsic::vector_insert)
369     return false;
370 
371   // ..that is inserting into index zero of an undef vector..
372   if (!isa<UndefValue>(IntrI->getOperand(0)) ||
373       !cast<ConstantInt>(IntrI->getOperand(2))->isZero())
374     return false;
375 
376   // ..where the value inserted comes from a load..
377   auto *Load = dyn_cast<LoadInst>(IntrI->getOperand(1));
378   if (!Load || !Load->isSimple())
379     return false;
380 
381   // ..that is loading a predicate vector sized worth of bits..
382   if (Load->getType() != FixedPredType)
383     return false;
384 
385   IRBuilder<> Builder(I->getContext());
386   Builder.SetInsertPoint(Load);
387 
388   auto *PtrBitCast = Builder.CreateBitCast(
389       Load->getPointerOperand(),
390       PredType->getPointerTo(Load->getPointerAddressSpace()));
391   auto *LoadPred = Builder.CreateLoad(PredType, PtrBitCast);
392 
393   BitCast->replaceAllUsesWith(LoadPred);
394   BitCast->eraseFromParent();
395   if (IntrI->getNumUses() == 0)
396     IntrI->eraseFromParent();
397   if (Load->getNumUses() == 0)
398     Load->eraseFromParent();
399 
400   return true;
401 }
402 
403 bool SVEIntrinsicOpts::optimizeInstructions(
404     SmallSetVector<Function *, 4> &Functions) {
405   bool Changed = false;
406 
407   for (auto *F : Functions) {
408     DominatorTree *DT = &getAnalysis<DominatorTreeWrapperPass>(*F).getDomTree();
409 
410     // Traverse the DT with an rpo walk so we see defs before uses, allowing
411     // simplification to be done incrementally.
412     BasicBlock *Root = DT->getRoot();
413     ReversePostOrderTraversal<BasicBlock *> RPOT(Root);
414     for (auto *BB : RPOT) {
415       for (Instruction &I : make_early_inc_range(*BB)) {
416         switch (I.getOpcode()) {
417         case Instruction::Store:
418           Changed |= optimizePredicateStore(&I);
419           break;
420         case Instruction::BitCast:
421           Changed |= optimizePredicateLoad(&I);
422           break;
423         }
424       }
425     }
426   }
427 
428   return Changed;
429 }
430 
431 bool SVEIntrinsicOpts::optimizeFunctions(
432     SmallSetVector<Function *, 4> &Functions) {
433   bool Changed = false;
434 
435   Changed |= optimizePTrueIntrinsicCalls(Functions);
436   Changed |= optimizeInstructions(Functions);
437 
438   return Changed;
439 }
440 
441 bool SVEIntrinsicOpts::runOnModule(Module &M) {
442   bool Changed = false;
443   SmallSetVector<Function *, 4> Functions;
444 
445   // Check for SVE intrinsic declarations first so that we only iterate over
446   // relevant functions. Where an appropriate declaration is found, store the
447   // function(s) where it is used so we can target these only.
448   for (auto &F : M.getFunctionList()) {
449     if (!F.isDeclaration())
450       continue;
451 
452     switch (F.getIntrinsicID()) {
453     case Intrinsic::vector_extract:
454     case Intrinsic::vector_insert:
455     case Intrinsic::aarch64_sve_ptrue:
456       for (User *U : F.users())
457         Functions.insert(cast<Instruction>(U)->getFunction());
458       break;
459     default:
460       break;
461     }
462   }
463 
464   if (!Functions.empty())
465     Changed |= optimizeFunctions(Functions);
466 
467   return Changed;
468 }
469