1 //===-- AArch64StackTaggingPreRA.cpp --- Stack Tagging for AArch64 -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 
10 #include "AArch64.h"
11 #include "AArch64MachineFunctionInfo.h"
12 #include "AArch64InstrInfo.h"
13 #include "llvm/ADT/SetVector.h"
14 #include "llvm/ADT/Statistic.h"
15 #include "llvm/CodeGen/MachineBranchProbabilityInfo.h"
16 #include "llvm/CodeGen/MachineFrameInfo.h"
17 #include "llvm/CodeGen/MachineFunction.h"
18 #include "llvm/CodeGen/MachineFunctionPass.h"
19 #include "llvm/CodeGen/MachineInstrBuilder.h"
20 #include "llvm/CodeGen/MachineLoopInfo.h"
21 #include "llvm/CodeGen/MachineRegisterInfo.h"
22 #include "llvm/CodeGen/MachineTraceMetrics.h"
23 #include "llvm/CodeGen/Passes.h"
24 #include "llvm/CodeGen/TargetInstrInfo.h"
25 #include "llvm/CodeGen/TargetRegisterInfo.h"
26 #include "llvm/CodeGen/TargetSubtargetInfo.h"
27 #include "llvm/Support/CommandLine.h"
28 #include "llvm/Support/Debug.h"
29 #include "llvm/Support/raw_ostream.h"
30 
31 using namespace llvm;
32 
33 #define DEBUG_TYPE "aarch64-stack-tagging-pre-ra"
34 
35 enum UncheckedLdStMode { UncheckedNever, UncheckedSafe, UncheckedAlways };
36 
37 cl::opt<UncheckedLdStMode> ClUncheckedLdSt(
38     "stack-tagging-unchecked-ld-st", cl::Hidden,
39     cl::init(UncheckedSafe),
40     cl::desc(
41         "Unconditionally apply unchecked-ld-st optimization (even for large "
42         "stack frames, or in the presence of variable sized allocas)."),
43     cl::values(
44         clEnumValN(UncheckedNever, "never", "never apply unchecked-ld-st"),
45         clEnumValN(
46             UncheckedSafe, "safe",
47             "apply unchecked-ld-st when the target is definitely within range"),
48         clEnumValN(UncheckedAlways, "always", "always apply unchecked-ld-st")));
49 
50 static cl::opt<bool>
51     ClFirstSlot("stack-tagging-first-slot-opt", cl::Hidden, cl::init(true),
52                 cl::desc("Apply first slot optimization for stack tagging "
53                          "(eliminate ADDG Rt, Rn, 0, 0)."));
54 
55 namespace {
56 
57 class AArch64StackTaggingPreRA : public MachineFunctionPass {
58   MachineFunction *MF;
59   AArch64FunctionInfo *AFI;
60   MachineFrameInfo *MFI;
61   MachineRegisterInfo *MRI;
62   const AArch64RegisterInfo *TRI;
63   const AArch64InstrInfo *TII;
64 
65   SmallVector<MachineInstr*, 16> ReTags;
66 
67 public:
68   static char ID;
69   AArch64StackTaggingPreRA() : MachineFunctionPass(ID) {
70     initializeAArch64StackTaggingPreRAPass(*PassRegistry::getPassRegistry());
71   }
72 
73   bool mayUseUncheckedLoadStore();
74   void uncheckUsesOf(unsigned TaggedReg, int FI);
75   void uncheckLoadsAndStores();
76   std::optional<int> findFirstSlotCandidate();
77 
78   bool runOnMachineFunction(MachineFunction &Func) override;
79   StringRef getPassName() const override {
80     return "AArch64 Stack Tagging PreRA";
81   }
82 
83   void getAnalysisUsage(AnalysisUsage &AU) const override {
84     AU.setPreservesCFG();
85     MachineFunctionPass::getAnalysisUsage(AU);
86   }
87 };
88 } // end anonymous namespace
89 
90 char AArch64StackTaggingPreRA::ID = 0;
91 
92 INITIALIZE_PASS_BEGIN(AArch64StackTaggingPreRA, "aarch64-stack-tagging-pre-ra",
93                       "AArch64 Stack Tagging PreRA Pass", false, false)
94 INITIALIZE_PASS_END(AArch64StackTaggingPreRA, "aarch64-stack-tagging-pre-ra",
95                     "AArch64 Stack Tagging PreRA Pass", false, false)
96 
97 FunctionPass *llvm::createAArch64StackTaggingPreRAPass() {
98   return new AArch64StackTaggingPreRA();
99 }
100 
101 static bool isUncheckedLoadOrStoreOpcode(unsigned Opcode) {
102   switch (Opcode) {
103   case AArch64::LDRBBui:
104   case AArch64::LDRHHui:
105   case AArch64::LDRWui:
106   case AArch64::LDRXui:
107 
108   case AArch64::LDRBui:
109   case AArch64::LDRHui:
110   case AArch64::LDRSui:
111   case AArch64::LDRDui:
112   case AArch64::LDRQui:
113 
114   case AArch64::LDRSHWui:
115   case AArch64::LDRSHXui:
116 
117   case AArch64::LDRSBWui:
118   case AArch64::LDRSBXui:
119 
120   case AArch64::LDRSWui:
121 
122   case AArch64::STRBBui:
123   case AArch64::STRHHui:
124   case AArch64::STRWui:
125   case AArch64::STRXui:
126 
127   case AArch64::STRBui:
128   case AArch64::STRHui:
129   case AArch64::STRSui:
130   case AArch64::STRDui:
131   case AArch64::STRQui:
132 
133   case AArch64::LDPWi:
134   case AArch64::LDPXi:
135   case AArch64::LDPSi:
136   case AArch64::LDPDi:
137   case AArch64::LDPQi:
138 
139   case AArch64::LDPSWi:
140 
141   case AArch64::STPWi:
142   case AArch64::STPXi:
143   case AArch64::STPSi:
144   case AArch64::STPDi:
145   case AArch64::STPQi:
146     return true;
147   default:
148     return false;
149   }
150 }
151 
152 bool AArch64StackTaggingPreRA::mayUseUncheckedLoadStore() {
153   if (ClUncheckedLdSt == UncheckedNever)
154     return false;
155   else if (ClUncheckedLdSt == UncheckedAlways)
156     return true;
157 
158   // This estimate can be improved if we had harder guarantees about stack frame
159   // layout. With LocalStackAllocation we can estimate SP offset to any
160   // preallocated slot. AArch64FrameLowering::orderFrameObjects could put tagged
161   // objects ahead of non-tagged ones, but that's not always desirable.
162   //
163   // Underestimating SP offset here may require the use of LDG to materialize
164   // the tagged address of the stack slot, along with a scratch register
165   // allocation (post-regalloc!).
166   //
167   // For now we do the safe thing here and require that the entire stack frame
168   // is within range of the shortest of the unchecked instructions.
169   unsigned FrameSize = 0;
170   for (unsigned i = 0, e = MFI->getObjectIndexEnd(); i != e; ++i)
171     FrameSize += MFI->getObjectSize(i);
172   bool EntireFrameReachableFromSP = FrameSize < 0xf00;
173   return !MFI->hasVarSizedObjects() && EntireFrameReachableFromSP;
174 }
175 
176 void AArch64StackTaggingPreRA::uncheckUsesOf(unsigned TaggedReg, int FI) {
177   for (MachineInstr &UseI :
178        llvm::make_early_inc_range(MRI->use_instructions(TaggedReg))) {
179     if (isUncheckedLoadOrStoreOpcode(UseI.getOpcode())) {
180       // FI operand is always the one before the immediate offset.
181       unsigned OpIdx = TII->getLoadStoreImmIdx(UseI.getOpcode()) - 1;
182       if (UseI.getOperand(OpIdx).isReg() &&
183           UseI.getOperand(OpIdx).getReg() == TaggedReg) {
184         UseI.getOperand(OpIdx).ChangeToFrameIndex(FI);
185         UseI.getOperand(OpIdx).setTargetFlags(AArch64II::MO_TAGGED);
186       }
187     } else if (UseI.isCopy() && UseI.getOperand(0).getReg().isVirtual()) {
188       uncheckUsesOf(UseI.getOperand(0).getReg(), FI);
189     }
190   }
191 }
192 
193 void AArch64StackTaggingPreRA::uncheckLoadsAndStores() {
194   for (auto *I : ReTags) {
195     Register TaggedReg = I->getOperand(0).getReg();
196     int FI = I->getOperand(1).getIndex();
197     uncheckUsesOf(TaggedReg, FI);
198   }
199 }
200 
201 namespace {
202 struct SlotWithTag {
203   int FI;
204   int Tag;
205   SlotWithTag(int FI, int Tag) : FI(FI), Tag(Tag) {}
206   explicit SlotWithTag(const MachineInstr &MI)
207       : FI(MI.getOperand(1).getIndex()), Tag(MI.getOperand(4).getImm()) {}
208   bool operator==(const SlotWithTag &Other) const {
209     return FI == Other.FI && Tag == Other.Tag;
210   }
211 };
212 } // namespace
213 
214 namespace llvm {
215 template <> struct DenseMapInfo<SlotWithTag> {
216   static inline SlotWithTag getEmptyKey() { return {-2, -2}; }
217   static inline SlotWithTag getTombstoneKey() { return {-3, -3}; }
218   static unsigned getHashValue(const SlotWithTag &V) {
219     return hash_combine(DenseMapInfo<int>::getHashValue(V.FI),
220                         DenseMapInfo<int>::getHashValue(V.Tag));
221   }
222   static bool isEqual(const SlotWithTag &A, const SlotWithTag &B) {
223     return A == B;
224   }
225 };
226 } // namespace llvm
227 
228 static bool isSlotPreAllocated(MachineFrameInfo *MFI, int FI) {
229   return MFI->getUseLocalStackAllocationBlock() &&
230          MFI->isObjectPreAllocated(FI);
231 }
232 
233 // Pin one of the tagged slots to offset 0 from the tagged base pointer.
234 // This would make its address available in a virtual register (IRG's def), as
235 // opposed to requiring an ADDG instruction to materialize. This effectively
236 // eliminates a vreg (by replacing it with direct uses of IRG, which is usually
237 // live almost everywhere anyway), and therefore needs to happen before
238 // regalloc.
239 std::optional<int> AArch64StackTaggingPreRA::findFirstSlotCandidate() {
240   // Find the best (FI, Tag) pair to pin to offset 0.
241   // Looking at the possible uses of a tagged address, the advantage of pinning
242   // is:
243   // - COPY to physical register.
244   //   Does not matter, this would trade a MOV instruction for an ADDG.
245   // - ST*G matter, but those mostly appear near the function prologue where all
246   //   the tagged addresses need to be materialized anyway; also, counting ST*G
247   //   uses would overweight large allocas that require more than one ST*G
248   //   instruction.
249   // - Load/Store instructions in the address operand do not require a tagged
250   //   pointer, so they also do not benefit. These operands have already been
251   //   eliminated (see uncheckLoadsAndStores) so all remaining load/store
252   //   instructions count.
253   // - Any other instruction may benefit from being pinned to offset 0.
254   LLVM_DEBUG(dbgs() << "AArch64StackTaggingPreRA::findFirstSlotCandidate\n");
255   if (!ClFirstSlot)
256     return std::nullopt;
257 
258   DenseMap<SlotWithTag, int> RetagScore;
259   SlotWithTag MaxScoreST{-1, -1};
260   int MaxScore = -1;
261   for (auto *I : ReTags) {
262     SlotWithTag ST{*I};
263     if (isSlotPreAllocated(MFI, ST.FI))
264       continue;
265 
266     Register RetagReg = I->getOperand(0).getReg();
267     if (!RetagReg.isVirtual())
268       continue;
269 
270     int Score = 0;
271     SmallVector<Register, 8> WorkList;
272     WorkList.push_back(RetagReg);
273 
274     while (!WorkList.empty()) {
275       Register UseReg = WorkList.pop_back_val();
276       for (auto &UseI : MRI->use_instructions(UseReg)) {
277         unsigned Opcode = UseI.getOpcode();
278         if (Opcode == AArch64::STGi || Opcode == AArch64::ST2Gi ||
279             Opcode == AArch64::STZGi || Opcode == AArch64::STZ2Gi ||
280             Opcode == AArch64::STGPi || Opcode == AArch64::STGloop ||
281             Opcode == AArch64::STZGloop || Opcode == AArch64::STGloop_wback ||
282             Opcode == AArch64::STZGloop_wback)
283           continue;
284         if (UseI.isCopy()) {
285           Register DstReg = UseI.getOperand(0).getReg();
286           if (DstReg.isVirtual())
287             WorkList.push_back(DstReg);
288           continue;
289         }
290         LLVM_DEBUG(dbgs() << "[" << ST.FI << ":" << ST.Tag << "] use of %"
291                           << Register::virtReg2Index(UseReg) << " in " << UseI
292                           << "\n");
293         Score++;
294       }
295     }
296 
297     int TotalScore = RetagScore[ST] += Score;
298     if (TotalScore > MaxScore ||
299         (TotalScore == MaxScore && ST.FI > MaxScoreST.FI)) {
300       MaxScore = TotalScore;
301       MaxScoreST = ST;
302     }
303   }
304 
305   if (MaxScoreST.FI < 0)
306     return std::nullopt;
307 
308   // If FI's tag is already 0, we are done.
309   if (MaxScoreST.Tag == 0)
310     return MaxScoreST.FI;
311 
312   // Otherwise, find a random victim pair (FI, Tag) where Tag == 0.
313   SlotWithTag SwapST{-1, -1};
314   for (auto *I : ReTags) {
315     SlotWithTag ST{*I};
316     if (ST.Tag == 0) {
317       SwapST = ST;
318       break;
319     }
320   }
321 
322   // Swap tags between the victim and the highest scoring pair.
323   // If SwapWith is still (-1, -1), that's fine, too - we'll simply take tag for
324   // the highest score slot without changing anything else.
325   for (auto *&I : ReTags) {
326     SlotWithTag ST{*I};
327     MachineOperand &TagOp = I->getOperand(4);
328     if (ST == MaxScoreST) {
329       TagOp.setImm(0);
330     } else if (ST == SwapST) {
331       TagOp.setImm(MaxScoreST.Tag);
332     }
333   }
334   return MaxScoreST.FI;
335 }
336 
337 bool AArch64StackTaggingPreRA::runOnMachineFunction(MachineFunction &Func) {
338   MF = &Func;
339   MRI = &MF->getRegInfo();
340   AFI = MF->getInfo<AArch64FunctionInfo>();
341   TII = static_cast<const AArch64InstrInfo *>(MF->getSubtarget().getInstrInfo());
342   TRI = static_cast<const AArch64RegisterInfo *>(
343       MF->getSubtarget().getRegisterInfo());
344   MFI = &MF->getFrameInfo();
345   ReTags.clear();
346 
347   assert(MRI->isSSA());
348 
349   LLVM_DEBUG(dbgs() << "********** AArch64 Stack Tagging PreRA **********\n"
350                     << "********** Function: " << MF->getName() << '\n');
351 
352   SmallSetVector<int, 8> TaggedSlots;
353   for (auto &BB : *MF) {
354     for (auto &I : BB) {
355       if (I.getOpcode() == AArch64::TAGPstack) {
356         ReTags.push_back(&I);
357         int FI = I.getOperand(1).getIndex();
358         TaggedSlots.insert(FI);
359         // There should be no offsets in TAGP yet.
360         assert(I.getOperand(2).getImm() == 0);
361       }
362     }
363   }
364 
365   // Take over from SSP. It does nothing for tagged slots, and should not really
366   // have been enabled in the first place.
367   for (int FI : TaggedSlots)
368     MFI->setObjectSSPLayout(FI, MachineFrameInfo::SSPLK_None);
369 
370   if (ReTags.empty())
371     return false;
372 
373   if (mayUseUncheckedLoadStore())
374     uncheckLoadsAndStores();
375 
376   // Find a slot that is used with zero tag offset, like ADDG #fi, 0.
377   // If the base tagged pointer is set up to the address of this slot,
378   // the ADDG instruction can be eliminated.
379   std::optional<int> BaseSlot = findFirstSlotCandidate();
380   if (BaseSlot)
381     AFI->setTaggedBasePointerIndex(*BaseSlot);
382 
383   for (auto *I : ReTags) {
384     int FI = I->getOperand(1).getIndex();
385     int Tag = I->getOperand(4).getImm();
386     Register Base = I->getOperand(3).getReg();
387     if (Tag == 0 && FI == BaseSlot) {
388       BuildMI(*I->getParent(), I, {}, TII->get(AArch64::COPY),
389               I->getOperand(0).getReg())
390           .addReg(Base);
391       I->eraseFromParent();
392     }
393   }
394 
395   return true;
396 }
397