1 //==-- AArch64CompressJumpTables.cpp - Compress jump tables for AArch64 --====//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 // This pass looks at the basic blocks each jump-table refers to and works out
8 // whether they can be emitted in a compressed form (with 8 or 16-bit
9 // entries). If so, it changes the opcode and flags them in the associated
10 // AArch64FunctionInfo.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "AArch64.h"
15 #include "AArch64MachineFunctionInfo.h"
16 #include "AArch64Subtarget.h"
17 #include "llvm/ADT/Statistic.h"
18 #include "llvm/CodeGen/MachineFunctionPass.h"
19 #include "llvm/CodeGen/MachineJumpTableInfo.h"
20 #include "llvm/CodeGen/TargetInstrInfo.h"
21 #include "llvm/CodeGen/TargetSubtargetInfo.h"
22 #include "llvm/MC/MCContext.h"
23 #include "llvm/Support/Alignment.h"
24 #include "llvm/Support/Debug.h"
25 
26 using namespace llvm;
27 
// Group name for the STATISTIC counters below; also used as the pass argument
// in INITIALIZE_PASS.
#define DEBUG_TYPE "aarch64-jump-tables"

// How many JumpTableDest32 instructions ended up with 1-, 2- or 4-byte table
// entries. NumJT32 counts tables that could not be compressed.
STATISTIC(NumJT8, "Number of jump-tables with 1-byte entries");
STATISTIC(NumJT16, "Number of jump-tables with 2-byte entries");
STATISTIC(NumJT32, "Number of jump-tables with 4-byte entries");
33 
34 namespace {
35 class AArch64CompressJumpTables : public MachineFunctionPass {
36   const TargetInstrInfo *TII;
37   MachineFunction *MF;
38   SmallVector<int, 8> BlockInfo;
39 
40   /// Returns the size in instructions of the block \p MBB, or None if we
41   /// couldn't get a safe upper bound.
42   Optional<int> computeBlockSize(MachineBasicBlock &MBB);
43 
44   /// Gather information about the function, returns false if we can't perform
45   /// this optimization for some reason.
46   bool scanFunction();
47 
48   bool compressJumpTable(MachineInstr &MI, int Offset);
49 
50 public:
51   static char ID;
AArch64CompressJumpTables()52   AArch64CompressJumpTables() : MachineFunctionPass(ID) {
53     initializeAArch64CompressJumpTablesPass(*PassRegistry::getPassRegistry());
54   }
55 
56   bool runOnMachineFunction(MachineFunction &MF) override;
57 
getRequiredProperties() const58   MachineFunctionProperties getRequiredProperties() const override {
59     return MachineFunctionProperties().set(
60         MachineFunctionProperties::Property::NoVRegs);
61   }
getPassName() const62   StringRef getPassName() const override {
63     return "AArch64 Compress Jump Tables";
64   }
65 };
66 char AArch64CompressJumpTables::ID = 0;
67 } // namespace
68 
// Register the pass with the legacy pass registry under the argument name
// DEBUG_TYPE ("aarch64-jump-tables"). The two trailing flags are the
// INITIALIZE_PASS `cfg` and `analysis` arguments: this is neither a CFG-only
// pass nor an analysis.
INITIALIZE_PASS(AArch64CompressJumpTables, DEBUG_TYPE,
                "AArch64 compress jump tables pass", false, false)
71 
72 Optional<int>
computeBlockSize(MachineBasicBlock & MBB)73 AArch64CompressJumpTables::computeBlockSize(MachineBasicBlock &MBB) {
74   int Size = 0;
75   for (const MachineInstr &MI : MBB) {
76     // Inline asm may contain some directives like .bytes which we don't
77     // currently have the ability to parse accurately. To be safe, just avoid
78     // computing a size and bail out.
79     if (MI.getOpcode() == AArch64::INLINEASM ||
80         MI.getOpcode() == AArch64::INLINEASM_BR)
81       return None;
82     Size += TII->getInstSizeInBytes(MI);
83   }
84   return Size;
85 }
86 
scanFunction()87 bool AArch64CompressJumpTables::scanFunction() {
88   BlockInfo.clear();
89   BlockInfo.resize(MF->getNumBlockIDs());
90 
91   unsigned Offset = 0;
92   for (MachineBasicBlock &MBB : *MF) {
93     const Align Alignment = MBB.getAlignment();
94     unsigned AlignedOffset;
95     if (Alignment == Align(1))
96       AlignedOffset = Offset;
97     else
98       AlignedOffset = alignTo(Offset, Alignment);
99     BlockInfo[MBB.getNumber()] = AlignedOffset;
100     auto BlockSize = computeBlockSize(MBB);
101     if (!BlockSize)
102       return false;
103     Offset = AlignedOffset + *BlockSize;
104   }
105   return true;
106 }
107 
compressJumpTable(MachineInstr & MI,int Offset)108 bool AArch64CompressJumpTables::compressJumpTable(MachineInstr &MI,
109                                                   int Offset) {
110   if (MI.getOpcode() != AArch64::JumpTableDest32)
111     return false;
112 
113   int JTIdx = MI.getOperand(4).getIndex();
114   auto &JTInfo = *MF->getJumpTableInfo();
115   const MachineJumpTableEntry &JT = JTInfo.getJumpTables()[JTIdx];
116 
117   // The jump-table might have been optimized away.
118   if (JT.MBBs.empty())
119     return false;
120 
121   int MaxOffset = std::numeric_limits<int>::min(),
122       MinOffset = std::numeric_limits<int>::max();
123   MachineBasicBlock *MinBlock = nullptr;
124   for (auto *Block : JT.MBBs) {
125     int BlockOffset = BlockInfo[Block->getNumber()];
126     assert(BlockOffset % 4 == 0 && "misaligned basic block");
127 
128     MaxOffset = std::max(MaxOffset, BlockOffset);
129     if (BlockOffset <= MinOffset) {
130       MinOffset = BlockOffset;
131       MinBlock = Block;
132     }
133   }
134   assert(MinBlock && "Failed to find minimum offset block");
135 
136   // The ADR instruction needed to calculate the address of the first reachable
137   // basic block can address +/-1MB.
138   if (!isInt<21>(MinOffset - Offset)) {
139     ++NumJT32;
140     return false;
141   }
142 
143   int Span = MaxOffset - MinOffset;
144   auto *AFI = MF->getInfo<AArch64FunctionInfo>();
145   if (isUInt<8>(Span / 4)) {
146     AFI->setJumpTableEntryInfo(JTIdx, 1, MinBlock->getSymbol());
147     MI.setDesc(TII->get(AArch64::JumpTableDest8));
148     ++NumJT8;
149     return true;
150   }
151   if (isUInt<16>(Span / 4)) {
152     AFI->setJumpTableEntryInfo(JTIdx, 2, MinBlock->getSymbol());
153     MI.setDesc(TII->get(AArch64::JumpTableDest16));
154     ++NumJT16;
155     return true;
156   }
157 
158   ++NumJT32;
159   return false;
160 }
161 
runOnMachineFunction(MachineFunction & MFIn)162 bool AArch64CompressJumpTables::runOnMachineFunction(MachineFunction &MFIn) {
163   bool Changed = false;
164   MF = &MFIn;
165 
166   const auto &ST = MF->getSubtarget<AArch64Subtarget>();
167   TII = ST.getInstrInfo();
168 
169   if (ST.force32BitJumpTables() && !MF->getFunction().hasMinSize())
170     return false;
171 
172   if (!scanFunction())
173     return false;
174 
175   for (MachineBasicBlock &MBB : *MF) {
176     int Offset = BlockInfo[MBB.getNumber()];
177     for (MachineInstr &MI : MBB) {
178       Changed |= compressJumpTable(MI, Offset);
179       Offset += TII->getInstSizeInBytes(MI);
180     }
181   }
182 
183   return Changed;
184 }
185 
createAArch64CompressJumpTablesPass()186 FunctionPass *llvm::createAArch64CompressJumpTablesPass() {
187   return new AArch64CompressJumpTables();
188 }
189