//===- MVELaneInterleaving.cpp - Interleave for MVE instructions ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This pass interleaves around sext/zext/trunc instructions. MVE does not have
10 // a single sext/zext or trunc instruction that takes the bottom half of a
11 // vector and extends to a full width, like NEON has with MOVL. Instead it is
12 // expected that this happens through top/bottom instructions. So the MVE
13 // equivalent VMOVLT/B instructions take either the even or odd elements of the
14 // input and extend them to the larger type, producing a vector with half the
15 // number of elements each of double the bitwidth. As there is no simple
16 // instruction, we often have to turn sext/zext/trunc into a series of lane
17 // moves (or stack loads/stores, which we do not do yet).
18 //
19 // This pass takes vector code that starts at truncs, looks for interconnected
20 // blobs of operations that end with sext/zext (or constants/splats) of the
21 // form:
22 //   %sa = sext v8i16 %a to v8i32
23 //   %sb = sext v8i16 %b to v8i32
24 //   %add = add v8i32 %sa, %sb
25 //   %r = trunc %add to v8i16
// And adds shuffles to allow the use of VMOVL/VMOVN instructions:
27 //   %sha = shuffle v8i16 %a, undef, <0, 2, 4, 6, 1, 3, 5, 7>
28 //   %sa = sext v8i16 %sha to v8i32
29 //   %shb = shuffle v8i16 %b, undef, <0, 2, 4, 6, 1, 3, 5, 7>
30 //   %sb = sext v8i16 %shb to v8i32
31 //   %add = add v8i32 %sa, %sb
32 //   %r = trunc %add to v8i16
33 //   %shr = shuffle v8i16 %r, undef, <0, 4, 1, 5, 2, 6, 3, 7>
34 // Which can then be split and lowered to MVE instructions efficiently:
35 //   %sa_b = VMOVLB.s16 %a
36 //   %sa_t = VMOVLT.s16 %a
37 //   %sb_b = VMOVLB.s16 %b
38 //   %sb_t = VMOVLT.s16 %b
39 //   %add_b = VADD.i32 %sa_b, %sb_b
40 //   %add_t = VADD.i32 %sa_t, %sb_t
41 //   %r = VMOVNT.i16 %add_b, %add_t
42 //
43 //===----------------------------------------------------------------------===//
44 
45 #include "ARM.h"
46 #include "ARMBaseInstrInfo.h"
47 #include "ARMSubtarget.h"
48 #include "llvm/ADT/SetVector.h"
49 #include "llvm/Analysis/TargetTransformInfo.h"
50 #include "llvm/CodeGen/TargetLowering.h"
51 #include "llvm/CodeGen/TargetPassConfig.h"
52 #include "llvm/CodeGen/TargetSubtargetInfo.h"
53 #include "llvm/IR/BasicBlock.h"
54 #include "llvm/IR/Constant.h"
55 #include "llvm/IR/Constants.h"
56 #include "llvm/IR/DerivedTypes.h"
57 #include "llvm/IR/Function.h"
58 #include "llvm/IR/IRBuilder.h"
59 #include "llvm/IR/InstIterator.h"
60 #include "llvm/IR/InstrTypes.h"
61 #include "llvm/IR/Instruction.h"
62 #include "llvm/IR/Instructions.h"
63 #include "llvm/IR/IntrinsicInst.h"
64 #include "llvm/IR/Intrinsics.h"
65 #include "llvm/IR/IntrinsicsARM.h"
66 #include "llvm/IR/PatternMatch.h"
67 #include "llvm/IR/Type.h"
68 #include "llvm/IR/Value.h"
69 #include "llvm/InitializePasses.h"
70 #include "llvm/Pass.h"
71 #include "llvm/Support/Casting.h"
72 #include <algorithm>
73 #include <cassert>
74 
75 using namespace llvm;
76 
77 #define DEBUG_TYPE "mve-laneinterleave"
78 
// Hidden command-line switch, enabled by default, that gates the whole pass:
// runOnFunction returns without making changes when it is set to false.
cl::opt<bool> EnableInterleave(
    "enable-mve-interleave", cl::Hidden, cl::init(true),
    cl::desc("Enable interleave MVE vector operation lowering"));
82 
83 namespace {
84 
85 class MVELaneInterleaving : public FunctionPass {
86 public:
87   static char ID; // Pass identification, replacement for typeid
88 
89   explicit MVELaneInterleaving() : FunctionPass(ID) {
90     initializeMVELaneInterleavingPass(*PassRegistry::getPassRegistry());
91   }
92 
93   bool runOnFunction(Function &F) override;
94 
95   StringRef getPassName() const override { return "MVE lane interleaving"; }
96 
97   void getAnalysisUsage(AnalysisUsage &AU) const override {
98     AU.setPreservesCFG();
99     AU.addRequired<TargetPassConfig>();
100     FunctionPass::getAnalysisUsage(AU);
101   }
102 };
103 
104 } // end anonymous namespace
105 
// Out-of-class definition of the pass identifier declared in the class.
char MVELaneInterleaving::ID = 0;

// Registers the pass under the DEBUG_TYPE argument string with the description
// "MVE lane interleaving". The two trailing `false` arguments are the macro's
// "CFG only" and "is analysis" flags.
INITIALIZE_PASS(MVELaneInterleaving, DEBUG_TYPE, "MVE lane interleaving", false,
                false)

// Factory used by code outside this file to obtain a new instance of the pass.
// The returned object is heap allocated; presumably the pass manager it is
// added to takes ownership -- confirm against the caller.
Pass *llvm::createMVELaneInterleavingPass() {
  return new MVELaneInterleaving();
}
114 
115 static bool isProfitableToInterleave(SmallSetVector<Instruction *, 4> &Exts,
116                                      SmallSetVector<Instruction *, 4> &Truncs) {
117   // This is not always beneficial to transform. Exts can be incorporated into
118   // loads, Truncs can be folded into stores.
119   // Truncs are usually the same number of instructions,
120   //  VSTRH.32(A);VSTRH.32(B) vs VSTRH.16(VMOVNT A, B) with interleaving
121   // Exts are unfortunately more instructions in the general case:
122   //  A=VLDRH.32; B=VLDRH.32;
123   // vs with interleaving:
124   //  T=VLDRH.16; A=VMOVNB T; B=VMOVNT T
125   // But those VMOVL may be folded into a VMULL.
126 
127   // But expensive extends/truncs are always good to remove. FPExts always
128   // involve extra VCVT's so are always considered to be beneficial to convert.
129   for (auto *E : Exts) {
130     if (isa<FPExtInst>(E) || !isa<LoadInst>(E->getOperand(0))) {
131       LLVM_DEBUG(dbgs() << "Beneficial due to " << *E << "\n");
132       return true;
133     }
134   }
135   for (auto *T : Truncs) {
136     if (T->hasOneUse() && !isa<StoreInst>(*T->user_begin())) {
137       LLVM_DEBUG(dbgs() << "Beneficial due to " << *T << "\n");
138       return true;
139     }
140   }
141 
142   // Otherwise, we know we have a load(ext), see if any of the Extends are a
143   // vmull. This is a simple heuristic and certainly not perfect.
144   for (auto *E : Exts) {
145     if (!E->hasOneUse() ||
146         cast<Instruction>(*E->user_begin())->getOpcode() != Instruction::Mul) {
147       LLVM_DEBUG(dbgs() << "Not beneficial due to " << *E << "\n");
148       return false;
149     }
150   }
151   return true;
152 }
153 
154 static bool tryInterleave(Instruction *Start,
155                           SmallPtrSetImpl<Instruction *> &Visited) {
156   LLVM_DEBUG(dbgs() << "tryInterleave from " << *Start << "\n");
157 
158   if (!isa<Instruction>(Start->getOperand(0)))
159     return false;
160 
161   // Look for connected operations starting from Ext's, terminating at Truncs.
162   std::vector<Instruction *> Worklist;
163   Worklist.push_back(Start);
164   Worklist.push_back(cast<Instruction>(Start->getOperand(0)));
165 
166   SmallSetVector<Instruction *, 4> Truncs;
167   SmallSetVector<Instruction *, 4> Reducts;
168   SmallSetVector<Instruction *, 4> Exts;
169   SmallSetVector<Use *, 4> OtherLeafs;
170   SmallSetVector<Instruction *, 4> Ops;
171 
172   while (!Worklist.empty()) {
173     Instruction *I = Worklist.back();
174     Worklist.pop_back();
175 
176     switch (I->getOpcode()) {
177     // Truncs
178     case Instruction::Trunc:
179     case Instruction::FPTrunc:
180       if (!Truncs.insert(I))
181         continue;
182       Visited.insert(I);
183       break;
184 
185     // Extend leafs
186     case Instruction::SExt:
187     case Instruction::ZExt:
188     case Instruction::FPExt:
189       if (Exts.count(I))
190         continue;
191       for (auto *Use : I->users())
192         Worklist.push_back(cast<Instruction>(Use));
193       Exts.insert(I);
194       break;
195 
196     case Instruction::Call: {
197       IntrinsicInst *II = dyn_cast<IntrinsicInst>(I);
198       if (!II)
199         return false;
200 
201       if (II->getIntrinsicID() == Intrinsic::vector_reduce_add) {
202         if (!Reducts.insert(I))
203           continue;
204         Visited.insert(I);
205         break;
206       }
207 
208       switch (II->getIntrinsicID()) {
209       case Intrinsic::abs:
210       case Intrinsic::smin:
211       case Intrinsic::smax:
212       case Intrinsic::umin:
213       case Intrinsic::umax:
214       case Intrinsic::sadd_sat:
215       case Intrinsic::ssub_sat:
216       case Intrinsic::uadd_sat:
217       case Intrinsic::usub_sat:
218       case Intrinsic::minnum:
219       case Intrinsic::maxnum:
220       case Intrinsic::fabs:
221       case Intrinsic::fma:
222       case Intrinsic::ceil:
223       case Intrinsic::floor:
224       case Intrinsic::rint:
225       case Intrinsic::round:
226       case Intrinsic::trunc:
227         break;
228       default:
229         return false;
230       }
231       [[fallthrough]]; // Fall through to treating these like an operator below.
232     }
233     // Binary/tertiary ops
234     case Instruction::Add:
235     case Instruction::Sub:
236     case Instruction::Mul:
237     case Instruction::AShr:
238     case Instruction::LShr:
239     case Instruction::Shl:
240     case Instruction::ICmp:
241     case Instruction::FCmp:
242     case Instruction::FAdd:
243     case Instruction::FMul:
244     case Instruction::Select:
245       if (!Ops.insert(I))
246         continue;
247 
248       for (Use &Op : I->operands()) {
249         if (!isa<FixedVectorType>(Op->getType()))
250           continue;
251         if (isa<Instruction>(Op))
252           Worklist.push_back(cast<Instruction>(&Op));
253         else
254           OtherLeafs.insert(&Op);
255       }
256 
257       for (auto *Use : I->users())
258         Worklist.push_back(cast<Instruction>(Use));
259       break;
260 
261     case Instruction::ShuffleVector:
262       // A shuffle of a splat is a splat.
263       if (cast<ShuffleVectorInst>(I)->isZeroEltSplat())
264         continue;
265       [[fallthrough]];
266 
267     default:
268       LLVM_DEBUG(dbgs() << "  Unhandled instruction: " << *I << "\n");
269       return false;
270     }
271   }
272 
273   if (Exts.empty() && OtherLeafs.empty())
274     return false;
275 
276   LLVM_DEBUG({
277     dbgs() << "Found group:\n  Exts:\n";
278     for (auto *I : Exts)
279       dbgs() << "  " << *I << "\n";
280     dbgs() << "  Ops:\n";
281     for (auto *I : Ops)
282       dbgs() << "  " << *I << "\n";
283     dbgs() << "  OtherLeafs:\n";
284     for (auto *I : OtherLeafs)
285       dbgs() << "  " << *I->get() << " of " << *I->getUser() << "\n";
286     dbgs() << "  Truncs:\n";
287     for (auto *I : Truncs)
288       dbgs() << "  " << *I << "\n";
289     dbgs() << "  Reducts:\n";
290     for (auto *I : Reducts)
291       dbgs() << "  " << *I << "\n";
292   });
293 
294   assert((!Truncs.empty() || !Reducts.empty()) &&
295          "Expected some truncs or reductions");
296   if (Truncs.empty() && Exts.empty())
297     return false;
298 
299   auto *VT = !Truncs.empty()
300                  ? cast<FixedVectorType>(Truncs[0]->getType())
301                  : cast<FixedVectorType>(Exts[0]->getOperand(0)->getType());
302   LLVM_DEBUG(dbgs() << "Using VT:" << *VT << "\n");
303 
304   // Check types
305   unsigned NumElts = VT->getNumElements();
306   unsigned BaseElts = VT->getScalarSizeInBits() == 16
307                           ? 8
308                           : (VT->getScalarSizeInBits() == 8 ? 16 : 0);
309   if (BaseElts == 0 || NumElts % BaseElts != 0) {
310     LLVM_DEBUG(dbgs() << "  Type is unsupported\n");
311     return false;
312   }
313   if (Start->getOperand(0)->getType()->getScalarSizeInBits() !=
314       VT->getScalarSizeInBits() * 2) {
315     LLVM_DEBUG(dbgs() << "  Type not double sized\n");
316     return false;
317   }
318   for (Instruction *I : Exts)
319     if (I->getOperand(0)->getType() != VT) {
320       LLVM_DEBUG(dbgs() << "  Wrong type on " << *I << "\n");
321       return false;
322     }
323   for (Instruction *I : Truncs)
324     if (I->getType() != VT) {
325       LLVM_DEBUG(dbgs() << "  Wrong type on " << *I << "\n");
326       return false;
327     }
328 
329   // Check that it looks beneficial
330   if (!isProfitableToInterleave(Exts, Truncs))
331     return false;
332   if (!Reducts.empty() && (Ops.empty() || all_of(Ops, [](Instruction *I) {
333                              return I->getOpcode() == Instruction::Mul ||
334                                     I->getOpcode() == Instruction::Select ||
335                                     I->getOpcode() == Instruction::ICmp;
336                            }))) {
337     LLVM_DEBUG(dbgs() << "Reduction does not look profitable\n");
338     return false;
339   }
340 
341   // Create new shuffles around the extends / truncs / other leaves.
342   IRBuilder<> Builder(Start);
343 
344   SmallVector<int, 16> LeafMask;
345   SmallVector<int, 16> TruncMask;
346   // LeafMask : 0, 2, 4, 6, 1, 3, 5, 7   8, 10, 12, 14,  9, 11, 13, 15
347   // TruncMask: 0, 4, 1, 5, 2, 6, 3, 7   8, 12,  9, 13, 10, 14, 11, 15
348   for (unsigned Base = 0; Base < NumElts; Base += BaseElts) {
349     for (unsigned i = 0; i < BaseElts / 2; i++)
350       LeafMask.push_back(Base + i * 2);
351     for (unsigned i = 0; i < BaseElts / 2; i++)
352       LeafMask.push_back(Base + i * 2 + 1);
353   }
354   for (unsigned Base = 0; Base < NumElts; Base += BaseElts) {
355     for (unsigned i = 0; i < BaseElts / 2; i++) {
356       TruncMask.push_back(Base + i);
357       TruncMask.push_back(Base + i + BaseElts / 2);
358     }
359   }
360 
361   for (Instruction *I : Exts) {
362     LLVM_DEBUG(dbgs() << "Replacing ext " << *I << "\n");
363     Builder.SetInsertPoint(I);
364     Value *Shuffle = Builder.CreateShuffleVector(I->getOperand(0), LeafMask);
365     bool FPext = isa<FPExtInst>(I);
366     bool Sext = isa<SExtInst>(I);
367     Value *Ext = FPext ? Builder.CreateFPExt(Shuffle, I->getType())
368                        : Sext ? Builder.CreateSExt(Shuffle, I->getType())
369                               : Builder.CreateZExt(Shuffle, I->getType());
370     I->replaceAllUsesWith(Ext);
371     LLVM_DEBUG(dbgs() << "  with " << *Shuffle << "\n");
372   }
373 
374   for (Use *I : OtherLeafs) {
375     LLVM_DEBUG(dbgs() << "Replacing leaf " << *I << "\n");
376     Builder.SetInsertPoint(cast<Instruction>(I->getUser()));
377     Value *Shuffle = Builder.CreateShuffleVector(I->get(), LeafMask);
378     I->getUser()->setOperand(I->getOperandNo(), Shuffle);
379     LLVM_DEBUG(dbgs() << "  with " << *Shuffle << "\n");
380   }
381 
382   for (Instruction *I : Truncs) {
383     LLVM_DEBUG(dbgs() << "Replacing trunc " << *I << "\n");
384 
385     Builder.SetInsertPoint(I->getParent(), ++I->getIterator());
386     Value *Shuf = Builder.CreateShuffleVector(I, TruncMask);
387     I->replaceAllUsesWith(Shuf);
388     cast<Instruction>(Shuf)->setOperand(0, I);
389 
390     LLVM_DEBUG(dbgs() << "  with " << *Shuf << "\n");
391   }
392 
393   return true;
394 }
395 
396 // Add reductions are fairly common and associative, meaning we can start the
397 // interleaving from them and don't need to emit a shuffle.
398 static bool isAddReduction(Instruction &I) {
399   if (auto *II = dyn_cast<IntrinsicInst>(&I))
400     return II->getIntrinsicID() == Intrinsic::vector_reduce_add;
401   return false;
402 }
403 
404 bool MVELaneInterleaving::runOnFunction(Function &F) {
405   if (!EnableInterleave)
406     return false;
407   auto &TPC = getAnalysis<TargetPassConfig>();
408   auto &TM = TPC.getTM<TargetMachine>();
409   auto *ST = &TM.getSubtarget<ARMSubtarget>(F);
410   if (!ST->hasMVEIntegerOps())
411     return false;
412 
413   bool Changed = false;
414 
415   SmallPtrSet<Instruction *, 16> Visited;
416   for (Instruction &I : reverse(instructions(F))) {
417     if (((I.getType()->isVectorTy() &&
418           (isa<TruncInst>(I) || isa<FPTruncInst>(I))) ||
419          isAddReduction(I)) &&
420         !Visited.count(&I))
421       Changed |= tryInterleave(&I, Visited);
422   }
423 
424   return Changed;
425 }
426