1 //===- GCNRegPressure.h -----------------------------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 ///
9 /// \file
/// This file defines the GCNRegPressure class, which tracks register pressure
/// by bookkeeping the number of SGPRs/VGPRs used and the weights of large
/// SGPRs/VGPRs. It also implements a compare function, which compares
/// different register pressures and declares the one with max occupancy the
/// winner.
14 ///
15 //===----------------------------------------------------------------------===//
16
17 #ifndef LLVM_LIB_TARGET_AMDGPU_GCNREGPRESSURE_H
18 #define LLVM_LIB_TARGET_AMDGPU_GCNREGPRESSURE_H
19
20 #include "GCNSubtarget.h"
21 #include "llvm/CodeGen/LiveIntervals.h"
22 #include <algorithm>
23
24 namespace llvm {
25
26 class MachineRegisterInfo;
27 class raw_ostream;
28 class SlotIndex;
29
30 struct GCNRegPressure {
31 enum RegKind {
32 SGPR32,
33 SGPR_TUPLE,
34 VGPR32,
35 VGPR_TUPLE,
36 AGPR32,
37 AGPR_TUPLE,
38 TOTAL_KINDS
39 };
40
GCNRegPressureGCNRegPressure41 GCNRegPressure() {
42 clear();
43 }
44
emptyGCNRegPressure45 bool empty() const { return getSGPRNum() == 0 && getVGPRNum(false) == 0; }
46
clearGCNRegPressure47 void clear() { std::fill(&Value[0], &Value[TOTAL_KINDS], 0); }
48
getSGPRNumGCNRegPressure49 unsigned getSGPRNum() const { return Value[SGPR32]; }
getVGPRNumGCNRegPressure50 unsigned getVGPRNum(bool UnifiedVGPRFile) const {
51 if (UnifiedVGPRFile) {
52 return Value[AGPR32] ? alignTo(Value[VGPR32], 4) + Value[AGPR32]
53 : Value[VGPR32] + Value[AGPR32];
54 }
55 return std::max(Value[VGPR32], Value[AGPR32]);
56 }
getAGPRNumGCNRegPressure57 unsigned getAGPRNum() const { return Value[AGPR32]; }
58
getVGPRTuplesWeightGCNRegPressure59 unsigned getVGPRTuplesWeight() const { return std::max(Value[VGPR_TUPLE],
60 Value[AGPR_TUPLE]); }
getSGPRTuplesWeightGCNRegPressure61 unsigned getSGPRTuplesWeight() const { return Value[SGPR_TUPLE]; }
62
getOccupancyGCNRegPressure63 unsigned getOccupancy(const GCNSubtarget &ST) const {
64 return std::min(ST.getOccupancyWithNumSGPRs(getSGPRNum()),
65 ST.getOccupancyWithNumVGPRs(getVGPRNum(ST.hasGFX90AInsts())));
66 }
67
68 void inc(unsigned Reg,
69 LaneBitmask PrevMask,
70 LaneBitmask NewMask,
71 const MachineRegisterInfo &MRI);
72
higherOccupancyGCNRegPressure73 bool higherOccupancy(const GCNSubtarget &ST, const GCNRegPressure& O) const {
74 return getOccupancy(ST) > O.getOccupancy(ST);
75 }
76
77 bool less(const GCNSubtarget &ST, const GCNRegPressure& O,
78 unsigned MaxOccupancy = std::numeric_limits<unsigned>::max()) const;
79
80 bool operator==(const GCNRegPressure &O) const {
81 return std::equal(&Value[0], &Value[TOTAL_KINDS], O.Value);
82 }
83
84 bool operator!=(const GCNRegPressure &O) const {
85 return !(*this == O);
86 }
87
88 GCNRegPressure &operator+=(const GCNRegPressure &RHS) {
89 for (unsigned I = 0; I < TOTAL_KINDS; ++I)
90 Value[I] += RHS.Value[I];
91 return *this;
92 }
93
94 GCNRegPressure &operator-=(const GCNRegPressure &RHS) {
95 for (unsigned I = 0; I < TOTAL_KINDS; ++I)
96 Value[I] -= RHS.Value[I];
97 return *this;
98 }
99
100 void dump() const;
101
102 private:
103 unsigned Value[TOTAL_KINDS];
104
105 static unsigned getRegKind(Register Reg, const MachineRegisterInfo &MRI);
106
107 friend GCNRegPressure max(const GCNRegPressure &P1,
108 const GCNRegPressure &P2);
109
110 friend Printable print(const GCNRegPressure &RP, const GCNSubtarget *ST);
111 };
112
/// Returns a pressure whose every counter is the larger of the matching
/// counters in \p P1 and \p P2.
inline GCNRegPressure max(const GCNRegPressure &P1, const GCNRegPressure &P2) {
  GCNRegPressure Largest = P1;
  for (unsigned Kind = 0; Kind != GCNRegPressure::TOTAL_KINDS; ++Kind) {
    if (Largest.Value[Kind] < P2.Value[Kind])
      Largest.Value[Kind] = P2.Value[Kind];
  }
  return Largest;
}
119
120 inline GCNRegPressure operator+(const GCNRegPressure &P1,
121 const GCNRegPressure &P2) {
122 GCNRegPressure Sum = P1;
123 Sum += P2;
124 return Sum;
125 }
126
127 inline GCNRegPressure operator-(const GCNRegPressure &P1,
128 const GCNRegPressure &P2) {
129 GCNRegPressure Diff = P1;
130 Diff -= P2;
131 return Diff;
132 }
133
134 class GCNRPTracker {
135 public:
136 using LiveRegSet = DenseMap<unsigned, LaneBitmask>;
137
138 protected:
139 const LiveIntervals &LIS;
140 LiveRegSet LiveRegs;
141 GCNRegPressure CurPressure, MaxPressure;
142 const MachineInstr *LastTrackedMI = nullptr;
143 mutable const MachineRegisterInfo *MRI = nullptr;
144
GCNRPTracker(const LiveIntervals & LIS_)145 GCNRPTracker(const LiveIntervals &LIS_) : LIS(LIS_) {}
146
147 void reset(const MachineInstr &MI, const LiveRegSet *LiveRegsCopy,
148 bool After);
149
150 public:
151 // live regs for the current state
decltype(LiveRegs)152 const decltype(LiveRegs) &getLiveRegs() const { return LiveRegs; }
getLastTrackedMI()153 const MachineInstr *getLastTrackedMI() const { return LastTrackedMI; }
154
clearMaxPressure()155 void clearMaxPressure() { MaxPressure.clear(); }
156
getPressure()157 GCNRegPressure getPressure() const { return CurPressure; }
158
moveLiveRegs()159 decltype(LiveRegs) moveLiveRegs() {
160 return std::move(LiveRegs);
161 }
162 };
163
164 GCNRPTracker::LiveRegSet getLiveRegs(SlotIndex SI, const LiveIntervals &LIS,
165 const MachineRegisterInfo &MRI);
166
167 class GCNUpwardRPTracker : public GCNRPTracker {
168 public:
GCNUpwardRPTracker(const LiveIntervals & LIS_)169 GCNUpwardRPTracker(const LiveIntervals &LIS_) : GCNRPTracker(LIS_) {}
170
171 // reset tracker and set live register set to the specified value.
172 void reset(const MachineRegisterInfo &MRI_, const LiveRegSet &LiveRegs_);
173
174 // reset tracker at the specified slot index.
reset(const MachineRegisterInfo & MRI,SlotIndex SI)175 void reset(const MachineRegisterInfo &MRI, SlotIndex SI) {
176 reset(MRI, llvm::getLiveRegs(SI, LIS, MRI));
177 }
178
179 // reset tracker to the end of the MBB.
reset(const MachineBasicBlock & MBB)180 void reset(const MachineBasicBlock &MBB) {
181 reset(MBB.getParent()->getRegInfo(),
182 LIS.getSlotIndexes()->getMBBEndIdx(&MBB));
183 }
184
185 // reset tracker to the point just after MI (in program order).
reset(const MachineInstr & MI)186 void reset(const MachineInstr &MI) {
187 reset(MI.getMF()->getRegInfo(), LIS.getInstructionIndex(MI).getDeadSlot());
188 }
189
190 // move to the state just before the MI (in program order).
191 void recede(const MachineInstr &MI);
192
193 // checks whether the tracker's state after receding MI corresponds
194 // to reported by LIS.
195 bool isValid() const;
196
getMaxPressure()197 const GCNRegPressure &getMaxPressure() const { return MaxPressure; }
198
resetMaxPressure()199 void resetMaxPressure() { MaxPressure = CurPressure; }
200
getMaxPressureAndReset()201 GCNRegPressure getMaxPressureAndReset() {
202 GCNRegPressure RP = MaxPressure;
203 resetMaxPressure();
204 return RP;
205 }
206 };
207
208 class GCNDownwardRPTracker : public GCNRPTracker {
209 // Last position of reset or advanceBeforeNext
210 MachineBasicBlock::const_iterator NextMI;
211
212 MachineBasicBlock::const_iterator MBBEnd;
213
214 public:
GCNDownwardRPTracker(const LiveIntervals & LIS_)215 GCNDownwardRPTracker(const LiveIntervals &LIS_) : GCNRPTracker(LIS_) {}
216
getNext()217 MachineBasicBlock::const_iterator getNext() const { return NextMI; }
218
219 // Return MaxPressure and clear it.
moveMaxPressure()220 GCNRegPressure moveMaxPressure() {
221 auto Res = MaxPressure;
222 MaxPressure.clear();
223 return Res;
224 }
225
226 // Reset tracker to the point before the MI
227 // filling live regs upon this point using LIS.
228 // Returns false if block is empty except debug values.
229 bool reset(const MachineInstr &MI, const LiveRegSet *LiveRegs = nullptr);
230
231 // Move to the state right before the next MI or after the end of MBB.
232 // Returns false if reached end of the block.
233 bool advanceBeforeNext();
234
235 // Move to the state at the MI, advanceBeforeNext has to be called first.
236 void advanceToNext();
237
238 // Move to the state at the next MI. Returns false if reached end of block.
239 bool advance();
240
241 // Advance instructions until before End.
242 bool advance(MachineBasicBlock::const_iterator End);
243
244 // Reset to Begin and advance to End.
245 bool advance(MachineBasicBlock::const_iterator Begin,
246 MachineBasicBlock::const_iterator End,
247 const LiveRegSet *LiveRegsCopy = nullptr);
248 };
249
250 LaneBitmask getLiveLaneMask(unsigned Reg,
251 SlotIndex SI,
252 const LiveIntervals &LIS,
253 const MachineRegisterInfo &MRI);
254
255 LaneBitmask getLiveLaneMask(const LiveInterval &LI, SlotIndex SI,
256 const MachineRegisterInfo &MRI);
257
258 GCNRPTracker::LiveRegSet getLiveRegs(SlotIndex SI, const LiveIntervals &LIS,
259 const MachineRegisterInfo &MRI);
260
261 /// creates a map MachineInstr -> LiveRegSet
262 /// R - range of iterators on instructions
263 /// After - upon entry or exit of every instruction
264 /// Note: there is no entry in the map for instructions with empty live reg set
265 /// Complexity = O(NumVirtRegs * averageLiveRangeSegmentsPerReg * lg(R))
266 template <typename Range>
267 DenseMap<MachineInstr*, GCNRPTracker::LiveRegSet>
getLiveRegMap(Range && R,bool After,LiveIntervals & LIS)268 getLiveRegMap(Range &&R, bool After, LiveIntervals &LIS) {
269 std::vector<SlotIndex> Indexes;
270 Indexes.reserve(std::distance(R.begin(), R.end()));
271 auto &SII = *LIS.getSlotIndexes();
272 for (MachineInstr *I : R) {
273 auto SI = SII.getInstructionIndex(*I);
274 Indexes.push_back(After ? SI.getDeadSlot() : SI.getBaseIndex());
275 }
276 llvm::sort(Indexes);
277
278 auto &MRI = (*R.begin())->getParent()->getParent()->getRegInfo();
279 DenseMap<MachineInstr *, GCNRPTracker::LiveRegSet> LiveRegMap;
280 SmallVector<SlotIndex, 32> LiveIdxs, SRLiveIdxs;
281 for (unsigned I = 0, E = MRI.getNumVirtRegs(); I != E; ++I) {
282 auto Reg = Register::index2VirtReg(I);
283 if (!LIS.hasInterval(Reg))
284 continue;
285 auto &LI = LIS.getInterval(Reg);
286 LiveIdxs.clear();
287 if (!LI.findIndexesLiveAt(Indexes, std::back_inserter(LiveIdxs)))
288 continue;
289 if (!LI.hasSubRanges()) {
290 for (auto SI : LiveIdxs)
291 LiveRegMap[SII.getInstructionFromIndex(SI)][Reg] =
292 MRI.getMaxLaneMaskForVReg(Reg);
293 } else
294 for (const auto &S : LI.subranges()) {
295 // constrain search for subranges by indexes live at main range
296 SRLiveIdxs.clear();
297 S.findIndexesLiveAt(LiveIdxs, std::back_inserter(SRLiveIdxs));
298 for (auto SI : SRLiveIdxs)
299 LiveRegMap[SII.getInstructionFromIndex(SI)][Reg] |= S.LaneMask;
300 }
301 }
302 return LiveRegMap;
303 }
304
getLiveRegsAfter(const MachineInstr & MI,const LiveIntervals & LIS)305 inline GCNRPTracker::LiveRegSet getLiveRegsAfter(const MachineInstr &MI,
306 const LiveIntervals &LIS) {
307 return getLiveRegs(LIS.getInstructionIndex(MI).getDeadSlot(), LIS,
308 MI.getParent()->getParent()->getRegInfo());
309 }
310
getLiveRegsBefore(const MachineInstr & MI,const LiveIntervals & LIS)311 inline GCNRPTracker::LiveRegSet getLiveRegsBefore(const MachineInstr &MI,
312 const LiveIntervals &LIS) {
313 return getLiveRegs(LIS.getInstructionIndex(MI).getBaseIndex(), LIS,
314 MI.getParent()->getParent()->getRegInfo());
315 }
316
317 template <typename Range>
getRegPressure(const MachineRegisterInfo & MRI,Range && LiveRegs)318 GCNRegPressure getRegPressure(const MachineRegisterInfo &MRI,
319 Range &&LiveRegs) {
320 GCNRegPressure Res;
321 for (const auto &RM : LiveRegs)
322 Res.inc(RM.first, LaneBitmask::getNone(), RM.second, MRI);
323 return Res;
324 }
325
326 bool isEqual(const GCNRPTracker::LiveRegSet &S1,
327 const GCNRPTracker::LiveRegSet &S2);
328
329 Printable print(const GCNRegPressure &RP, const GCNSubtarget *ST = nullptr);
330
331 Printable print(const GCNRPTracker::LiveRegSet &LiveRegs,
332 const MachineRegisterInfo &MRI);
333
334 Printable reportMismatch(const GCNRPTracker::LiveRegSet &LISLR,
335 const GCNRPTracker::LiveRegSet &TrackedL,
336 const TargetRegisterInfo *TRI, StringRef Pfx = " ");
337
338 struct GCNRegPressurePrinter : public MachineFunctionPass {
339 static char ID;
340
341 public:
GCNRegPressurePrinterGCNRegPressurePrinter342 GCNRegPressurePrinter() : MachineFunctionPass(ID) {}
343
344 bool runOnMachineFunction(MachineFunction &MF) override;
345
getAnalysisUsageGCNRegPressurePrinter346 void getAnalysisUsage(AnalysisUsage &AU) const override {
347 AU.addRequired<LiveIntervals>();
348 AU.setPreservesAll();
349 MachineFunctionPass::getAnalysisUsage(AU);
350 }
351 };
352
353 } // end namespace llvm
354
355 #endif // LLVM_LIB_TARGET_AMDGPU_GCNREGPRESSURE_H
356