1 //===- AArch64FalkorHWPFFix.cpp - Avoid HW prefetcher pitfalls on Falkor --===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 /// \file For Falkor, we want to avoid HW prefetcher instruction tag collisions
9 /// that may inhibit the HW prefetching.  This is done in two steps.  Before
10 /// ISel, we mark strided loads (i.e. those that will likely benefit from
11 /// prefetching) with metadata.  Then, after opcodes have been finalized, we
12 /// insert MOVs and re-write loads to prevent unintentional tag collisions.
13 // ===---------------------------------------------------------------------===//
14 
15 #include "AArch64.h"
16 #include "AArch64InstrInfo.h"
17 #include "AArch64Subtarget.h"
18 #include "AArch64TargetMachine.h"
19 #include "llvm/ADT/DenseMap.h"
20 #include "llvm/ADT/DepthFirstIterator.h"
21 #include "llvm/ADT/None.h"
22 #include "llvm/ADT/Optional.h"
23 #include "llvm/ADT/SmallVector.h"
24 #include "llvm/ADT/Statistic.h"
25 #include "llvm/Analysis/LoopInfo.h"
26 #include "llvm/Analysis/ScalarEvolution.h"
27 #include "llvm/Analysis/ScalarEvolutionExpressions.h"
28 #include "llvm/CodeGen/LiveRegUnits.h"
29 #include "llvm/CodeGen/MachineBasicBlock.h"
30 #include "llvm/CodeGen/MachineFunction.h"
31 #include "llvm/CodeGen/MachineFunctionPass.h"
32 #include "llvm/CodeGen/MachineInstr.h"
33 #include "llvm/CodeGen/MachineInstrBuilder.h"
34 #include "llvm/CodeGen/MachineLoopInfo.h"
35 #include "llvm/CodeGen/MachineOperand.h"
36 #include "llvm/CodeGen/MachineRegisterInfo.h"
37 #include "llvm/CodeGen/TargetPassConfig.h"
38 #include "llvm/CodeGen/TargetRegisterInfo.h"
39 #include "llvm/IR/DebugLoc.h"
40 #include "llvm/IR/Dominators.h"
41 #include "llvm/IR/Function.h"
42 #include "llvm/IR/Instruction.h"
43 #include "llvm/IR/Instructions.h"
44 #include "llvm/IR/Metadata.h"
45 #include "llvm/Pass.h"
46 #include "llvm/Support/Casting.h"
47 #include "llvm/Support/Debug.h"
48 #include "llvm/Support/DebugCounter.h"
49 #include "llvm/Support/raw_ostream.h"
50 #include <cassert>
51 #include <iterator>
52 #include <utility>
53 
54 using namespace llvm;
55 
56 #define DEBUG_TYPE "falkor-hwpf-fix"
57 
58 STATISTIC(NumStridedLoadsMarked, "Number of strided loads marked");
59 STATISTIC(NumCollisionsAvoided,
60           "Number of HW prefetch tag collisions avoided");
61 STATISTIC(NumCollisionsNotAvoided,
62           "Number of HW prefetch tag collisions not avoided due to lack of registers");
63 DEBUG_COUNTER(FixCounter, "falkor-hwpf",
64               "Controls which tag collisions are avoided");
65 
66 namespace {
67 
/// IR-level helper that tags loads with affine (strided) addresses inside
/// innermost loops using FALKOR_STRIDED_ACCESS_MD metadata, so the late
/// machine pass can identify them after ISel.
class FalkorMarkStridedAccesses {
public:
  FalkorMarkStridedAccesses(LoopInfo &LI, ScalarEvolution &SE)
      : LI(LI), SE(SE) {}

  /// Visit every loop in the function; returns true if any load was marked.
  bool run();

private:
  /// Mark strided loads in \p L (acts only on innermost loops); returns true
  /// if the IR was changed.
  bool runOnLoop(Loop &L);

  LoopInfo &LI;        // Enumerates the loops of the function.
  ScalarEvolution &SE; // Recognizes affine add-recurrence pointer SCEVs.
};
81 
/// Legacy pass-manager wrapper for the IR-level (pre-ISel) phase of the
/// Falkor HW prefetcher fix.
class FalkorMarkStridedAccessesLegacy : public FunctionPass {
public:
  static char ID; // Pass ID, replacement for typeid

  FalkorMarkStridedAccessesLegacy() : FunctionPass(ID) {
    initializeFalkorMarkStridedAccessesLegacyPass(
        *PassRegistry::getPassRegistry());
  }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    // TargetPassConfig is needed to query the subtarget: this pass only does
    // work on Falkor cores.
    AU.addRequired<TargetPassConfig>();
    AU.addPreserved<DominatorTreeWrapperPass>();
    AU.addRequired<LoopInfoWrapperPass>();
    AU.addPreserved<LoopInfoWrapperPass>();
    AU.addRequired<ScalarEvolutionWrapperPass>();
    AU.addPreserved<ScalarEvolutionWrapperPass>();
  }

  bool runOnFunction(Function &F) override;
};
102 
103 } // end anonymous namespace
104 
char FalkorMarkStridedAccessesLegacy::ID = 0;

// Register the IR-level pass and its analysis dependencies with the legacy
// pass manager.
INITIALIZE_PASS_BEGIN(FalkorMarkStridedAccessesLegacy, DEBUG_TYPE,
                      "Falkor HW Prefetch Fix", false, false)
INITIALIZE_PASS_DEPENDENCY(TargetPassConfig)
INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass)
INITIALIZE_PASS_END(FalkorMarkStridedAccessesLegacy, DEBUG_TYPE,
                    "Falkor HW Prefetch Fix", false, false)
114 
115 FunctionPass *llvm::createFalkorMarkStridedAccessesPass() {
116   return new FalkorMarkStridedAccessesLegacy();
117 }
118 
runOnFunction(Function & F)119 bool FalkorMarkStridedAccessesLegacy::runOnFunction(Function &F) {
120   TargetPassConfig &TPC = getAnalysis<TargetPassConfig>();
121   const AArch64Subtarget *ST =
122       TPC.getTM<AArch64TargetMachine>().getSubtargetImpl(F);
123   if (ST->getProcFamily() != AArch64Subtarget::Falkor)
124     return false;
125 
126   if (skipFunction(F))
127     return false;
128 
129   LoopInfo &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
130   ScalarEvolution &SE = getAnalysis<ScalarEvolutionWrapperPass>().getSE();
131 
132   FalkorMarkStridedAccesses LDP(LI, SE);
133   return LDP.run();
134 }
135 
run()136 bool FalkorMarkStridedAccesses::run() {
137   bool MadeChange = false;
138 
139   for (Loop *L : LI)
140     for (auto LIt = df_begin(L), LE = df_end(L); LIt != LE; ++LIt)
141       MadeChange |= runOnLoop(**LIt);
142 
143   return MadeChange;
144 }
145 
runOnLoop(Loop & L)146 bool FalkorMarkStridedAccesses::runOnLoop(Loop &L) {
147   // Only mark strided loads in the inner-most loop
148   if (!L.empty())
149     return false;
150 
151   bool MadeChange = false;
152 
153   for (BasicBlock *BB : L.blocks()) {
154     for (Instruction &I : *BB) {
155       LoadInst *LoadI = dyn_cast<LoadInst>(&I);
156       if (!LoadI)
157         continue;
158 
159       Value *PtrValue = LoadI->getPointerOperand();
160       if (L.isLoopInvariant(PtrValue))
161         continue;
162 
163       const SCEV *LSCEV = SE.getSCEV(PtrValue);
164       const SCEVAddRecExpr *LSCEVAddRec = dyn_cast<SCEVAddRecExpr>(LSCEV);
165       if (!LSCEVAddRec || !LSCEVAddRec->isAffine())
166         continue;
167 
168       LoadI->setMetadata(FALKOR_STRIDED_ACCESS_MD,
169                          MDNode::get(LoadI->getContext(), {}));
170       ++NumStridedLoadsMarked;
171       LLVM_DEBUG(dbgs() << "Load: " << I << " marked as strided\n");
172       MadeChange = true;
173     }
174   }
175 
176   return MadeChange;
177 }
178 
179 namespace {
180 
/// Late machine pass: after opcodes are finalized, rewrite strided loads
/// whose HW prefetcher tag collides with another load in the same loop by
/// copying the base pointer into a free scratch register.
class FalkorHWPFFix : public MachineFunctionPass {
public:
  static char ID;

  FalkorHWPFFix() : MachineFunctionPass(ID) {
    initializeFalkorHWPFFixPass(*PassRegistry::getPassRegistry());
  }

  bool runOnMachineFunction(MachineFunction &Fn) override;

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.setPreservesCFG();
    AU.addRequired<MachineLoopInfo>();
    MachineFunctionPass::getAnalysisUsage(AU);
  }

  // Runs after register allocation, so virtual registers must be gone.
  MachineFunctionProperties getRequiredProperties() const override {
    return MachineFunctionProperties().set(
        MachineFunctionProperties::Property::NoVRegs);
  }

private:
  /// Fix tag collisions among strided loads inside loop \p L.
  void runOnLoop(MachineLoop &L, MachineFunction &Fn);

  const AArch64InstrInfo *TII;
  const TargetRegisterInfo *TRI;
  // Maps a prefetcher tag to every load in the current loop producing it.
  DenseMap<unsigned, SmallVector<MachineInstr *, 4>> TagMap;
  bool Modified;
};
210 
/// Bits from load opcodes used to compute HW prefetcher instruction tags.
struct LoadInfo {
  LoadInfo() = default;

  Register DestReg;                  // Destination reg; invalid if untracked.
  Register BaseReg;                  // Address base register.
  int BaseRegIdx = -1;               // Operand index of the base register.
  const MachineOperand *OffsetOpnd = nullptr; // Offset operand, if any.
  bool IsPrePost = false;            // True for pre/post-indexed forms.
};
221 
222 } // end anonymous namespace
223 
char FalkorHWPFFix::ID = 0;

// Register the late machine pass with the legacy pass manager.
INITIALIZE_PASS_BEGIN(FalkorHWPFFix, "falkor-hwpf-fix-late",
                      "Falkor HW Prefetch Fix Late Phase", false, false)
INITIALIZE_PASS_DEPENDENCY(MachineLoopInfo)
INITIALIZE_PASS_END(FalkorHWPFFix, "falkor-hwpf-fix-late",
                    "Falkor HW Prefetch Fix Late Phase", false, false)
231 
/// Pack the destination, base, and offset encodings into the tag value used
/// to detect HW prefetcher collisions: dest in bits [3:0], base in bits
/// [7:4], offset in bits [13:8].
static unsigned makeTag(unsigned Dest, unsigned Base, unsigned Offset) {
  unsigned Tag = Dest & 0xf;
  Tag |= (Base & 0xf) << 4;
  Tag |= (Offset & 0x3f) << 8;
  return Tag;
}
235 
/// Decode, for every load opcode the HW prefetcher can observe, where the
/// destination register, base register, and offset operand live in \p MI's
/// operand list.  A -1 index means "not present / not tracked".  Returns
/// None for opcodes that are not recognized loads and for loads based on the
/// stack pointer (those do not get prefetched).
static Optional<LoadInfo> getLoadInfo(const MachineInstr &MI) {
  int DestRegIdx;
  int BaseRegIdx;
  int OffsetIdx;
  bool IsPrePost;

  switch (MI.getOpcode()) {
  default:
    return None;

  // Single-lane loads of 64-bit lanes: dest vector at 0, base at 3.
  case AArch64::LD1i64:
  case AArch64::LD2i64:
    DestRegIdx = 0;
    BaseRegIdx = 3;
    OffsetIdx = -1;
    IsPrePost = false;
    break;

  // Remaining single-lane loads: destination is a tied vector list that is
  // not tracked (DestRegIdx = -1); base at 3.
  case AArch64::LD1i8:
  case AArch64::LD1i16:
  case AArch64::LD1i32:
  case AArch64::LD2i8:
  case AArch64::LD2i16:
  case AArch64::LD2i32:
  case AArch64::LD3i8:
  case AArch64::LD3i16:
  case AArch64::LD3i32:
  case AArch64::LD3i64:
  case AArch64::LD4i8:
  case AArch64::LD4i16:
  case AArch64::LD4i32:
  case AArch64::LD4i64:
    DestRegIdx = -1;
    BaseRegIdx = 3;
    OffsetIdx = -1;
    IsPrePost = false;
    break;

  // One-register and replicating loads: dest at 0, base at 1.
  case AArch64::LD1Onev1d:
  case AArch64::LD1Onev2s:
  case AArch64::LD1Onev4h:
  case AArch64::LD1Onev8b:
  case AArch64::LD1Onev2d:
  case AArch64::LD1Onev4s:
  case AArch64::LD1Onev8h:
  case AArch64::LD1Onev16b:
  case AArch64::LD1Rv1d:
  case AArch64::LD1Rv2s:
  case AArch64::LD1Rv4h:
  case AArch64::LD1Rv8b:
  case AArch64::LD1Rv2d:
  case AArch64::LD1Rv4s:
  case AArch64::LD1Rv8h:
  case AArch64::LD1Rv16b:
    DestRegIdx = 0;
    BaseRegIdx = 1;
    OffsetIdx = -1;
    IsPrePost = false;
    break;

  // Multi-register vector loads: dest register list not tracked, base at 1.
  case AArch64::LD1Twov1d:
  case AArch64::LD1Twov2s:
  case AArch64::LD1Twov4h:
  case AArch64::LD1Twov8b:
  case AArch64::LD1Twov2d:
  case AArch64::LD1Twov4s:
  case AArch64::LD1Twov8h:
  case AArch64::LD1Twov16b:
  case AArch64::LD1Threev1d:
  case AArch64::LD1Threev2s:
  case AArch64::LD1Threev4h:
  case AArch64::LD1Threev8b:
  case AArch64::LD1Threev2d:
  case AArch64::LD1Threev4s:
  case AArch64::LD1Threev8h:
  case AArch64::LD1Threev16b:
  case AArch64::LD1Fourv1d:
  case AArch64::LD1Fourv2s:
  case AArch64::LD1Fourv4h:
  case AArch64::LD1Fourv8b:
  case AArch64::LD1Fourv2d:
  case AArch64::LD1Fourv4s:
  case AArch64::LD1Fourv8h:
  case AArch64::LD1Fourv16b:
  case AArch64::LD2Twov2s:
  case AArch64::LD2Twov4s:
  case AArch64::LD2Twov8b:
  case AArch64::LD2Twov2d:
  case AArch64::LD2Twov4h:
  case AArch64::LD2Twov8h:
  case AArch64::LD2Twov16b:
  case AArch64::LD2Rv1d:
  case AArch64::LD2Rv2s:
  case AArch64::LD2Rv4s:
  case AArch64::LD2Rv8b:
  case AArch64::LD2Rv2d:
  case AArch64::LD2Rv4h:
  case AArch64::LD2Rv8h:
  case AArch64::LD2Rv16b:
  case AArch64::LD3Threev2s:
  case AArch64::LD3Threev4h:
  case AArch64::LD3Threev8b:
  case AArch64::LD3Threev2d:
  case AArch64::LD3Threev4s:
  case AArch64::LD3Threev8h:
  case AArch64::LD3Threev16b:
  case AArch64::LD3Rv1d:
  case AArch64::LD3Rv2s:
  case AArch64::LD3Rv4h:
  case AArch64::LD3Rv8b:
  case AArch64::LD3Rv2d:
  case AArch64::LD3Rv4s:
  case AArch64::LD3Rv8h:
  case AArch64::LD3Rv16b:
  case AArch64::LD4Fourv2s:
  case AArch64::LD4Fourv4h:
  case AArch64::LD4Fourv8b:
  case AArch64::LD4Fourv2d:
  case AArch64::LD4Fourv4s:
  case AArch64::LD4Fourv8h:
  case AArch64::LD4Fourv16b:
  case AArch64::LD4Rv1d:
  case AArch64::LD4Rv2s:
  case AArch64::LD4Rv4h:
  case AArch64::LD4Rv8b:
  case AArch64::LD4Rv2d:
  case AArch64::LD4Rv4s:
  case AArch64::LD4Rv8h:
  case AArch64::LD4Rv16b:
    DestRegIdx = -1;
    BaseRegIdx = 1;
    OffsetIdx = -1;
    IsPrePost = false;
    break;

  // Post-indexed 64-bit lane loads: write-back base at 0 shifts the other
  // operands by one relative to the non-_POST forms.
  case AArch64::LD1i64_POST:
  case AArch64::LD2i64_POST:
    DestRegIdx = 1;
    BaseRegIdx = 4;
    OffsetIdx = 5;
    IsPrePost = true;
    break;

  // Post-indexed single-lane loads with untracked (tied) destinations.
  case AArch64::LD1i8_POST:
  case AArch64::LD1i16_POST:
  case AArch64::LD1i32_POST:
  case AArch64::LD2i8_POST:
  case AArch64::LD2i16_POST:
  case AArch64::LD2i32_POST:
  case AArch64::LD3i8_POST:
  case AArch64::LD3i16_POST:
  case AArch64::LD3i32_POST:
  case AArch64::LD3i64_POST:
  case AArch64::LD4i8_POST:
  case AArch64::LD4i16_POST:
  case AArch64::LD4i32_POST:
  case AArch64::LD4i64_POST:
    DestRegIdx = -1;
    BaseRegIdx = 4;
    OffsetIdx = 5;
    IsPrePost = true;
    break;

  // Post-indexed one-register / replicating loads.
  case AArch64::LD1Onev1d_POST:
  case AArch64::LD1Onev2s_POST:
  case AArch64::LD1Onev4h_POST:
  case AArch64::LD1Onev8b_POST:
  case AArch64::LD1Onev2d_POST:
  case AArch64::LD1Onev4s_POST:
  case AArch64::LD1Onev8h_POST:
  case AArch64::LD1Onev16b_POST:
  case AArch64::LD1Rv1d_POST:
  case AArch64::LD1Rv2s_POST:
  case AArch64::LD1Rv4h_POST:
  case AArch64::LD1Rv8b_POST:
  case AArch64::LD1Rv2d_POST:
  case AArch64::LD1Rv4s_POST:
  case AArch64::LD1Rv8h_POST:
  case AArch64::LD1Rv16b_POST:
    DestRegIdx = 1;
    BaseRegIdx = 2;
    OffsetIdx = 3;
    IsPrePost = true;
    break;

  // Post-indexed multi-register vector loads: dest list not tracked.
  case AArch64::LD1Twov1d_POST:
  case AArch64::LD1Twov2s_POST:
  case AArch64::LD1Twov4h_POST:
  case AArch64::LD1Twov8b_POST:
  case AArch64::LD1Twov2d_POST:
  case AArch64::LD1Twov4s_POST:
  case AArch64::LD1Twov8h_POST:
  case AArch64::LD1Twov16b_POST:
  case AArch64::LD1Threev1d_POST:
  case AArch64::LD1Threev2s_POST:
  case AArch64::LD1Threev4h_POST:
  case AArch64::LD1Threev8b_POST:
  case AArch64::LD1Threev2d_POST:
  case AArch64::LD1Threev4s_POST:
  case AArch64::LD1Threev8h_POST:
  case AArch64::LD1Threev16b_POST:
  case AArch64::LD1Fourv1d_POST:
  case AArch64::LD1Fourv2s_POST:
  case AArch64::LD1Fourv4h_POST:
  case AArch64::LD1Fourv8b_POST:
  case AArch64::LD1Fourv2d_POST:
  case AArch64::LD1Fourv4s_POST:
  case AArch64::LD1Fourv8h_POST:
  case AArch64::LD1Fourv16b_POST:
  case AArch64::LD2Twov2s_POST:
  case AArch64::LD2Twov4s_POST:
  case AArch64::LD2Twov8b_POST:
  case AArch64::LD2Twov2d_POST:
  case AArch64::LD2Twov4h_POST:
  case AArch64::LD2Twov8h_POST:
  case AArch64::LD2Twov16b_POST:
  case AArch64::LD2Rv1d_POST:
  case AArch64::LD2Rv2s_POST:
  case AArch64::LD2Rv4s_POST:
  case AArch64::LD2Rv8b_POST:
  case AArch64::LD2Rv2d_POST:
  case AArch64::LD2Rv4h_POST:
  case AArch64::LD2Rv8h_POST:
  case AArch64::LD2Rv16b_POST:
  case AArch64::LD3Threev2s_POST:
  case AArch64::LD3Threev4h_POST:
  case AArch64::LD3Threev8b_POST:
  case AArch64::LD3Threev2d_POST:
  case AArch64::LD3Threev4s_POST:
  case AArch64::LD3Threev8h_POST:
  case AArch64::LD3Threev16b_POST:
  case AArch64::LD3Rv1d_POST:
  case AArch64::LD3Rv2s_POST:
  case AArch64::LD3Rv4h_POST:
  case AArch64::LD3Rv8b_POST:
  case AArch64::LD3Rv2d_POST:
  case AArch64::LD3Rv4s_POST:
  case AArch64::LD3Rv8h_POST:
  case AArch64::LD3Rv16b_POST:
  case AArch64::LD4Fourv2s_POST:
  case AArch64::LD4Fourv4h_POST:
  case AArch64::LD4Fourv8b_POST:
  case AArch64::LD4Fourv2d_POST:
  case AArch64::LD4Fourv4s_POST:
  case AArch64::LD4Fourv8h_POST:
  case AArch64::LD4Fourv16b_POST:
  case AArch64::LD4Rv1d_POST:
  case AArch64::LD4Rv2s_POST:
  case AArch64::LD4Rv4h_POST:
  case AArch64::LD4Rv8b_POST:
  case AArch64::LD4Rv2d_POST:
  case AArch64::LD4Rv4s_POST:
  case AArch64::LD4Rv8h_POST:
  case AArch64::LD4Rv16b_POST:
    DestRegIdx = -1;
    BaseRegIdx = 2;
    OffsetIdx = 3;
    IsPrePost = true;
    break;

  // Plain scalar/FP loads: dest at 0, base at 1, offset (imm or reg) at 2.
  case AArch64::LDRBBroW:
  case AArch64::LDRBBroX:
  case AArch64::LDRBBui:
  case AArch64::LDRBroW:
  case AArch64::LDRBroX:
  case AArch64::LDRBui:
  case AArch64::LDRDl:
  case AArch64::LDRDroW:
  case AArch64::LDRDroX:
  case AArch64::LDRDui:
  case AArch64::LDRHHroW:
  case AArch64::LDRHHroX:
  case AArch64::LDRHHui:
  case AArch64::LDRHroW:
  case AArch64::LDRHroX:
  case AArch64::LDRHui:
  case AArch64::LDRQl:
  case AArch64::LDRQroW:
  case AArch64::LDRQroX:
  case AArch64::LDRQui:
  case AArch64::LDRSBWroW:
  case AArch64::LDRSBWroX:
  case AArch64::LDRSBWui:
  case AArch64::LDRSBXroW:
  case AArch64::LDRSBXroX:
  case AArch64::LDRSBXui:
  case AArch64::LDRSHWroW:
  case AArch64::LDRSHWroX:
  case AArch64::LDRSHWui:
  case AArch64::LDRSHXroW:
  case AArch64::LDRSHXroX:
  case AArch64::LDRSHXui:
  case AArch64::LDRSWl:
  case AArch64::LDRSWroW:
  case AArch64::LDRSWroX:
  case AArch64::LDRSWui:
  case AArch64::LDRSl:
  case AArch64::LDRSroW:
  case AArch64::LDRSroX:
  case AArch64::LDRSui:
  case AArch64::LDRWl:
  case AArch64::LDRWroW:
  case AArch64::LDRWroX:
  case AArch64::LDRWui:
  case AArch64::LDRXl:
  case AArch64::LDRXroW:
  case AArch64::LDRXroX:
  case AArch64::LDRXui:
  case AArch64::LDURBBi:
  case AArch64::LDURBi:
  case AArch64::LDURDi:
  case AArch64::LDURHHi:
  case AArch64::LDURHi:
  case AArch64::LDURQi:
  case AArch64::LDURSBWi:
  case AArch64::LDURSBXi:
  case AArch64::LDURSHWi:
  case AArch64::LDURSHXi:
  case AArch64::LDURSWi:
  case AArch64::LDURSi:
  case AArch64::LDURWi:
  case AArch64::LDURXi:
    DestRegIdx = 0;
    BaseRegIdx = 1;
    OffsetIdx = 2;
    IsPrePost = false;
    break;

  // Pre/post-indexed scalar loads: write-back base at 0, dest at 1.
  case AArch64::LDRBBpost:
  case AArch64::LDRBBpre:
  case AArch64::LDRBpost:
  case AArch64::LDRBpre:
  case AArch64::LDRDpost:
  case AArch64::LDRDpre:
  case AArch64::LDRHHpost:
  case AArch64::LDRHHpre:
  case AArch64::LDRHpost:
  case AArch64::LDRHpre:
  case AArch64::LDRQpost:
  case AArch64::LDRQpre:
  case AArch64::LDRSBWpost:
  case AArch64::LDRSBWpre:
  case AArch64::LDRSBXpost:
  case AArch64::LDRSBXpre:
  case AArch64::LDRSHWpost:
  case AArch64::LDRSHWpre:
  case AArch64::LDRSHXpost:
  case AArch64::LDRSHXpre:
  case AArch64::LDRSWpost:
  case AArch64::LDRSWpre:
  case AArch64::LDRSpost:
  case AArch64::LDRSpre:
  case AArch64::LDRWpost:
  case AArch64::LDRWpre:
  case AArch64::LDRXpost:
  case AArch64::LDRXpre:
    DestRegIdx = 1;
    BaseRegIdx = 2;
    OffsetIdx = 3;
    IsPrePost = true;
    break;

  // FP load-pair / load-non-temporal-pair: dest pair not tracked.
  case AArch64::LDNPDi:
  case AArch64::LDNPQi:
  case AArch64::LDNPSi:
  case AArch64::LDPQi:
  case AArch64::LDPDi:
  case AArch64::LDPSi:
    DestRegIdx = -1;
    BaseRegIdx = 2;
    OffsetIdx = 3;
    IsPrePost = false;
    break;

  // GPR load-pair: first destination at 0.
  case AArch64::LDPSWi:
  case AArch64::LDPWi:
  case AArch64::LDPXi:
    DestRegIdx = 0;
    BaseRegIdx = 2;
    OffsetIdx = 3;
    IsPrePost = false;
    break;

  // Pre/post-indexed FP load-pair: dest pair not tracked.
  case AArch64::LDPQpost:
  case AArch64::LDPQpre:
  case AArch64::LDPDpost:
  case AArch64::LDPDpre:
  case AArch64::LDPSpost:
  case AArch64::LDPSpre:
    DestRegIdx = -1;
    BaseRegIdx = 3;
    OffsetIdx = 4;
    IsPrePost = true;
    break;

  // Pre/post-indexed GPR load-pair: first destination at 1.
  case AArch64::LDPSWpost:
  case AArch64::LDPSWpre:
  case AArch64::LDPWpost:
  case AArch64::LDPWpre:
  case AArch64::LDPXpost:
  case AArch64::LDPXpre:
    DestRegIdx = 1;
    BaseRegIdx = 3;
    OffsetIdx = 4;
    IsPrePost = true;
    break;
  }

  // Loads from the stack pointer don't get prefetched.
  unsigned BaseReg = MI.getOperand(BaseRegIdx).getReg();
  if (BaseReg == AArch64::SP || BaseReg == AArch64::WSP)
    return None;

  LoadInfo LI;
  // An invalid Register() marks "destination not tracked" for tag purposes.
  LI.DestReg = DestRegIdx == -1 ? Register() : MI.getOperand(DestRegIdx).getReg();
  LI.BaseReg = BaseReg;
  LI.BaseRegIdx = BaseRegIdx;
  LI.OffsetOpnd = OffsetIdx == -1 ? nullptr : &MI.getOperand(OffsetIdx);
  LI.IsPrePost = IsPrePost;
  return LI;
}
657 
getTag(const TargetRegisterInfo * TRI,const MachineInstr & MI,const LoadInfo & LI)658 static Optional<unsigned> getTag(const TargetRegisterInfo *TRI,
659                                  const MachineInstr &MI, const LoadInfo &LI) {
660   unsigned Dest = LI.DestReg ? TRI->getEncodingValue(LI.DestReg) : 0;
661   unsigned Base = TRI->getEncodingValue(LI.BaseReg);
662   unsigned Off;
663   if (LI.OffsetOpnd == nullptr)
664     Off = 0;
665   else if (LI.OffsetOpnd->isGlobal() || LI.OffsetOpnd->isSymbol() ||
666            LI.OffsetOpnd->isCPI())
667     return None;
668   else if (LI.OffsetOpnd->isReg())
669     Off = (1 << 5) | TRI->getEncodingValue(LI.OffsetOpnd->getReg());
670   else
671     Off = LI.OffsetOpnd->getImm() >> 2;
672 
673   return makeTag(Dest, Base, Off);
674 }
675 
/// Scan one innermost machine loop, find strided loads whose prefetcher tag
/// collides with another load's, and rewrite them to use a copied base
/// register so the tags differ.  Sets Modified if any rewrite happens.
void FalkorHWPFFix::runOnLoop(MachineLoop &L, MachineFunction &Fn) {
  // Build the initial tag map for the whole loop.
  TagMap.clear();
  for (MachineBasicBlock *MBB : L.getBlocks())
    for (MachineInstr &MI : *MBB) {
      Optional<LoadInfo> LInfo = getLoadInfo(MI);
      if (!LInfo)
        continue;
      Optional<unsigned> Tag = getTag(TRI, MI, *LInfo);
      if (!Tag)
        continue;
      TagMap[*Tag].push_back(&MI);
    }

  // A collision only matters if at least one of the loads sharing a tag is
  // strided (i.e. likely to engage the HW prefetcher).
  bool AnyCollisions = false;
  for (auto &P : TagMap) {
    auto Size = P.second.size();
    if (Size > 1) {
      for (auto *MI : P.second) {
        if (TII->isStridedAccess(*MI)) {
          AnyCollisions = true;
          break;
        }
      }
    }
    if (AnyCollisions)
      break;
  }
  // Nothing to fix.
  if (!AnyCollisions)
    return;

  MachineRegisterInfo &MRI = Fn.getRegInfo();

  // Go through all the basic blocks in the current loop and fix any streaming
  // loads to avoid collisions with any other loads.
  LiveRegUnits LR(*TRI);
  for (MachineBasicBlock *MBB : L.getBlocks()) {
    LR.clear();
    LR.addLiveOuts(*MBB);
    // Walk backwards so LR describes the liveness just after each
    // instruction we visit; stepBackward happens before the increment.
    for (auto I = MBB->rbegin(); I != MBB->rend(); LR.stepBackward(*I), ++I) {
      MachineInstr &MI = *I;
      if (!TII->isStridedAccess(MI))
        continue;

      Optional<LoadInfo> OptLdI = getLoadInfo(MI);
      if (!OptLdI)
        continue;
      LoadInfo LdI = *OptLdI;
      Optional<unsigned> OptOldTag = getTag(TRI, MI, LdI);
      if (!OptOldTag)
        continue;
      // Only act when this load's tag is shared with at least one other load.
      auto &OldCollisions = TagMap[*OptOldTag];
      if (OldCollisions.size() <= 1)
        continue;

      bool Fixed = false;
      LLVM_DEBUG(dbgs() << "Attempting to fix tag collision: " << MI);

      if (!DebugCounter::shouldExecute(FixCounter)) {
        LLVM_DEBUG(dbgs() << "Skipping fix due to debug counter:\n  " << MI);
        continue;
      }

      // Add the non-base registers of MI as live so we don't use them as
      // scratch registers.
      for (unsigned OpI = 0, OpE = MI.getNumOperands(); OpI < OpE; ++OpI) {
        if (OpI == static_cast<unsigned>(LdI.BaseRegIdx))
          continue;
        MachineOperand &MO = MI.getOperand(OpI);
        if (MO.isReg() && MO.readsReg())
          LR.addReg(MO.getReg());
      }

      // Try each free, non-reserved GPR64 as the new base until one produces
      // a tag that is not already in use.
      for (unsigned ScratchReg : AArch64::GPR64RegClass) {
        if (!LR.available(ScratchReg) || MRI.isReserved(ScratchReg))
          continue;

        LoadInfo NewLdI(LdI);
        NewLdI.BaseReg = ScratchReg;
        unsigned NewTag = *getTag(TRI, MI, NewLdI);
        // Scratch reg tag would collide too, so don't use it.
        if (TagMap.count(NewTag))
          continue;

        LLVM_DEBUG(dbgs() << "Changing base reg to: "
                          << printReg(ScratchReg, TRI) << '\n');

        // Rewrite:
        //   Xd = LOAD Xb, off
        // to:
        //   Xc = MOV Xb
        //   Xd = LOAD Xc, off
        DebugLoc DL = MI.getDebugLoc();
        BuildMI(*MBB, &MI, DL, TII->get(AArch64::ORRXrs), ScratchReg)
            .addReg(AArch64::XZR)
            .addReg(LdI.BaseReg)
            .addImm(0);
        MachineOperand &BaseOpnd = MI.getOperand(LdI.BaseRegIdx);
        BaseOpnd.setReg(ScratchReg);

        // If the load does a pre/post increment, then insert a MOV after as
        // well to update the real base register.
        if (LdI.IsPrePost) {
          LLVM_DEBUG(dbgs() << "Doing post MOV of incremented reg: "
                            << printReg(ScratchReg, TRI) << '\n');
          MI.getOperand(0).setReg(
              ScratchReg); // Change tied operand pre/post update dest.
          BuildMI(*MBB, std::next(MachineBasicBlock::iterator(MI)), DL,
                  TII->get(AArch64::ORRXrs), LdI.BaseReg)
              .addReg(AArch64::XZR)
              .addReg(ScratchReg)
              .addImm(0);
        }

        // Remove MI from its old tag's collision list (swap-and-pop).
        for (int I = 0, E = OldCollisions.size(); I != E; ++I)
          if (OldCollisions[I] == &MI) {
            std::swap(OldCollisions[I], OldCollisions[E - 1]);
            OldCollisions.pop_back();
            break;
          }

        // Update TagMap to reflect instruction changes to reduce the number
        // of later MOVs to be inserted.  This needs to be done after
        // OldCollisions is updated since it may be relocated by this
        // insertion.
        TagMap[NewTag].push_back(&MI);
        ++NumCollisionsAvoided;
        Fixed = true;
        Modified = true;
        break;
      }
      if (!Fixed)
        ++NumCollisionsNotAvoided;
    }
  }
}
813 
runOnMachineFunction(MachineFunction & Fn)814 bool FalkorHWPFFix::runOnMachineFunction(MachineFunction &Fn) {
815   auto &ST = static_cast<const AArch64Subtarget &>(Fn.getSubtarget());
816   if (ST.getProcFamily() != AArch64Subtarget::Falkor)
817     return false;
818 
819   if (skipFunction(Fn.getFunction()))
820     return false;
821 
822   TII = static_cast<const AArch64InstrInfo *>(ST.getInstrInfo());
823   TRI = ST.getRegisterInfo();
824 
825   assert(TRI->trackLivenessAfterRegAlloc(Fn) &&
826          "Register liveness not available!");
827 
828   MachineLoopInfo &LI = getAnalysis<MachineLoopInfo>();
829 
830   Modified = false;
831 
832   for (MachineLoop *I : LI)
833     for (auto L = df_begin(I), LE = df_end(I); L != LE; ++L)
834       // Only process inner-loops
835       if (L->empty())
836         runOnLoop(**L, Fn);
837 
838   return Modified;
839 }
840 
createFalkorHWPFFixPass()841 FunctionPass *llvm::createFalkorHWPFFixPass() { return new FalkorHWPFFix(); }
842