1 //===- GCNRegPressure.h -----------------------------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 ///
9 /// \file
/// This file defines the GCNRegPressure class, which tracks register pressure
/// by bookkeeping the number of SGPRs/VGPRs used and the weights of large
/// SGPR/VGPR tuples. It also provides a compare function, which compares
/// different register pressures and declares the one with max occupancy the
/// winner.
14 ///
15 //===----------------------------------------------------------------------===//
16
17 #ifndef LLVM_LIB_TARGET_AMDGPU_GCNREGPRESSURE_H
18 #define LLVM_LIB_TARGET_AMDGPU_GCNREGPRESSURE_H
19
20 #include "GCNSubtarget.h"
21 #include "llvm/CodeGen/LiveIntervals.h"
22 #include <algorithm>
23
24 namespace llvm {
25
26 class MachineRegisterInfo;
27 class raw_ostream;
28 class SlotIndex;
29
30 struct GCNRegPressure {
31 enum RegKind {
32 SGPR32,
33 SGPR_TUPLE,
34 VGPR32,
35 VGPR_TUPLE,
36 AGPR32,
37 AGPR_TUPLE,
38 TOTAL_KINDS
39 };
40
GCNRegPressureGCNRegPressure41 GCNRegPressure() {
42 clear();
43 }
44
emptyGCNRegPressure45 bool empty() const { return getSGPRNum() == 0 && getVGPRNum(false) == 0; }
46
clearGCNRegPressure47 void clear() { std::fill(&Value[0], &Value[TOTAL_KINDS], 0); }
48
getSGPRNumGCNRegPressure49 unsigned getSGPRNum() const { return Value[SGPR32]; }
getVGPRNumGCNRegPressure50 unsigned getVGPRNum(bool UnifiedVGPRFile) const {
51 if (UnifiedVGPRFile) {
52 return Value[AGPR32] ? alignTo(Value[VGPR32], 4) + Value[AGPR32]
53 : Value[VGPR32] + Value[AGPR32];
54 }
55 return std::max(Value[VGPR32], Value[AGPR32]);
56 }
getAGPRNumGCNRegPressure57 unsigned getAGPRNum() const { return Value[AGPR32]; }
58
getVGPRTuplesWeightGCNRegPressure59 unsigned getVGPRTuplesWeight() const { return std::max(Value[VGPR_TUPLE],
60 Value[AGPR_TUPLE]); }
getSGPRTuplesWeightGCNRegPressure61 unsigned getSGPRTuplesWeight() const { return Value[SGPR_TUPLE]; }
62
getOccupancyGCNRegPressure63 unsigned getOccupancy(const GCNSubtarget &ST) const {
64 return std::min(ST.getOccupancyWithNumSGPRs(getSGPRNum()),
65 ST.getOccupancyWithNumVGPRs(getVGPRNum(ST.hasGFX90AInsts())));
66 }
67
68 void inc(unsigned Reg,
69 LaneBitmask PrevMask,
70 LaneBitmask NewMask,
71 const MachineRegisterInfo &MRI);
72
higherOccupancyGCNRegPressure73 bool higherOccupancy(const GCNSubtarget &ST, const GCNRegPressure& O) const {
74 return getOccupancy(ST) > O.getOccupancy(ST);
75 }
76
77 bool less(const GCNSubtarget &ST, const GCNRegPressure& O,
78 unsigned MaxOccupancy = std::numeric_limits<unsigned>::max()) const;
79
80 bool operator==(const GCNRegPressure &O) const {
81 return std::equal(&Value[0], &Value[TOTAL_KINDS], O.Value);
82 }
83
84 bool operator!=(const GCNRegPressure &O) const {
85 return !(*this == O);
86 }
87
88 void dump() const;
89
90 private:
91 unsigned Value[TOTAL_KINDS];
92
93 static unsigned getRegKind(Register Reg, const MachineRegisterInfo &MRI);
94
95 friend GCNRegPressure max(const GCNRegPressure &P1,
96 const GCNRegPressure &P2);
97
98 friend Printable print(const GCNRegPressure &RP, const GCNSubtarget *ST);
99 };
100
/// Element-wise maximum of two pressures. Each per-kind counter of the result
/// is the larger of the corresponding counters in \p P1 and \p P2.
inline GCNRegPressure max(const GCNRegPressure &P1, const GCNRegPressure &P2) {
  GCNRegPressure Result = P1;
  for (unsigned Kind = 0; Kind != GCNRegPressure::TOTAL_KINDS; ++Kind) {
    if (P2.Value[Kind] > Result.Value[Kind])
      Result.Value[Kind] = P2.Value[Kind];
  }
  return Result;
}
107
108 class GCNRPTracker {
109 public:
110 using LiveRegSet = DenseMap<unsigned, LaneBitmask>;
111
112 protected:
113 const LiveIntervals &LIS;
114 LiveRegSet LiveRegs;
115 GCNRegPressure CurPressure, MaxPressure;
116 const MachineInstr *LastTrackedMI = nullptr;
117 mutable const MachineRegisterInfo *MRI = nullptr;
118
GCNRPTracker(const LiveIntervals & LIS_)119 GCNRPTracker(const LiveIntervals &LIS_) : LIS(LIS_) {}
120
121 void reset(const MachineInstr &MI, const LiveRegSet *LiveRegsCopy,
122 bool After);
123
124 public:
125 // live regs for the current state
decltype(LiveRegs)126 const decltype(LiveRegs) &getLiveRegs() const { return LiveRegs; }
getLastTrackedMI()127 const MachineInstr *getLastTrackedMI() const { return LastTrackedMI; }
128
clearMaxPressure()129 void clearMaxPressure() { MaxPressure.clear(); }
130
131 // returns MaxPressure, resetting it
moveMaxPressure()132 decltype(MaxPressure) moveMaxPressure() {
133 auto Res = MaxPressure;
134 MaxPressure.clear();
135 return Res;
136 }
137
moveLiveRegs()138 decltype(LiveRegs) moveLiveRegs() {
139 return std::move(LiveRegs);
140 }
141 };
142
143 class GCNUpwardRPTracker : public GCNRPTracker {
144 public:
GCNUpwardRPTracker(const LiveIntervals & LIS_)145 GCNUpwardRPTracker(const LiveIntervals &LIS_) : GCNRPTracker(LIS_) {}
146
147 // reset tracker to the point just below MI
148 // filling live regs upon this point using LIS
149 void reset(const MachineInstr &MI, const LiveRegSet *LiveRegs = nullptr);
150
151 // move to the state just above the MI
152 void recede(const MachineInstr &MI);
153
154 // checks whether the tracker's state after receding MI corresponds
155 // to reported by LIS
156 bool isValid() const;
157 };
158
159 class GCNDownwardRPTracker : public GCNRPTracker {
160 // Last position of reset or advanceBeforeNext
161 MachineBasicBlock::const_iterator NextMI;
162
163 MachineBasicBlock::const_iterator MBBEnd;
164
165 public:
GCNDownwardRPTracker(const LiveIntervals & LIS_)166 GCNDownwardRPTracker(const LiveIntervals &LIS_) : GCNRPTracker(LIS_) {}
167
getNext()168 MachineBasicBlock::const_iterator getNext() const { return NextMI; }
169
170 // Reset tracker to the point before the MI
171 // filling live regs upon this point using LIS.
172 // Returns false if block is empty except debug values.
173 bool reset(const MachineInstr &MI, const LiveRegSet *LiveRegs = nullptr);
174
175 // Move to the state right before the next MI or after the end of MBB.
176 // Returns false if reached end of the block.
177 bool advanceBeforeNext();
178
179 // Move to the state at the MI, advanceBeforeNext has to be called first.
180 void advanceToNext();
181
182 // Move to the state at the next MI. Returns false if reached end of block.
183 bool advance();
184
185 // Advance instructions until before End.
186 bool advance(MachineBasicBlock::const_iterator End);
187
188 // Reset to Begin and advance to End.
189 bool advance(MachineBasicBlock::const_iterator Begin,
190 MachineBasicBlock::const_iterator End,
191 const LiveRegSet *LiveRegsCopy = nullptr);
192 };
193
/// Lane mask of \p Reg live at \p SI. NOTE(review): semantics inferred from
/// the name; defined out of line.
LaneBitmask getLiveLaneMask(unsigned Reg,
                            SlotIndex SI,
                            const LiveIntervals &LIS,
                            const MachineRegisterInfo &MRI);

/// Live register set at \p SI. getLiveRegsAfter/getLiveRegsBefore call this
/// with an instruction's dead slot / base index. Defined out of line.
GCNRPTracker::LiveRegSet getLiveRegs(SlotIndex SI,
                                     const LiveIntervals &LIS,
                                     const MachineRegisterInfo &MRI);
202
203 /// creates a map MachineInstr -> LiveRegSet
204 /// R - range of iterators on instructions
205 /// After - upon entry or exit of every instruction
206 /// Note: there is no entry in the map for instructions with empty live reg set
207 /// Complexity = O(NumVirtRegs * averageLiveRangeSegmentsPerReg * lg(R))
208 template <typename Range>
209 DenseMap<MachineInstr*, GCNRPTracker::LiveRegSet>
getLiveRegMap(Range && R,bool After,LiveIntervals & LIS)210 getLiveRegMap(Range &&R, bool After, LiveIntervals &LIS) {
211 std::vector<SlotIndex> Indexes;
212 Indexes.reserve(std::distance(R.begin(), R.end()));
213 auto &SII = *LIS.getSlotIndexes();
214 for (MachineInstr *I : R) {
215 auto SI = SII.getInstructionIndex(*I);
216 Indexes.push_back(After ? SI.getDeadSlot() : SI.getBaseIndex());
217 }
218 llvm::sort(Indexes);
219
220 auto &MRI = (*R.begin())->getParent()->getParent()->getRegInfo();
221 DenseMap<MachineInstr *, GCNRPTracker::LiveRegSet> LiveRegMap;
222 SmallVector<SlotIndex, 32> LiveIdxs, SRLiveIdxs;
223 for (unsigned I = 0, E = MRI.getNumVirtRegs(); I != E; ++I) {
224 auto Reg = Register::index2VirtReg(I);
225 if (!LIS.hasInterval(Reg))
226 continue;
227 auto &LI = LIS.getInterval(Reg);
228 LiveIdxs.clear();
229 if (!LI.findIndexesLiveAt(Indexes, std::back_inserter(LiveIdxs)))
230 continue;
231 if (!LI.hasSubRanges()) {
232 for (auto SI : LiveIdxs)
233 LiveRegMap[SII.getInstructionFromIndex(SI)][Reg] =
234 MRI.getMaxLaneMaskForVReg(Reg);
235 } else
236 for (const auto &S : LI.subranges()) {
237 // constrain search for subranges by indexes live at main range
238 SRLiveIdxs.clear();
239 S.findIndexesLiveAt(LiveIdxs, std::back_inserter(SRLiveIdxs));
240 for (auto SI : SRLiveIdxs)
241 LiveRegMap[SII.getInstructionFromIndex(SI)][Reg] |= S.LaneMask;
242 }
243 }
244 return LiveRegMap;
245 }
246
getLiveRegsAfter(const MachineInstr & MI,const LiveIntervals & LIS)247 inline GCNRPTracker::LiveRegSet getLiveRegsAfter(const MachineInstr &MI,
248 const LiveIntervals &LIS) {
249 return getLiveRegs(LIS.getInstructionIndex(MI).getDeadSlot(), LIS,
250 MI.getParent()->getParent()->getRegInfo());
251 }
252
getLiveRegsBefore(const MachineInstr & MI,const LiveIntervals & LIS)253 inline GCNRPTracker::LiveRegSet getLiveRegsBefore(const MachineInstr &MI,
254 const LiveIntervals &LIS) {
255 return getLiveRegs(LIS.getInstructionIndex(MI).getBaseIndex(), LIS,
256 MI.getParent()->getParent()->getRegInfo());
257 }
258
259 template <typename Range>
getRegPressure(const MachineRegisterInfo & MRI,Range && LiveRegs)260 GCNRegPressure getRegPressure(const MachineRegisterInfo &MRI,
261 Range &&LiveRegs) {
262 GCNRegPressure Res;
263 for (const auto &RM : LiveRegs)
264 Res.inc(RM.first, LaneBitmask::getNone(), RM.second, MRI);
265 return Res;
266 }
267
/// Compares two live register sets. NOTE(review): presumably true when both
/// hold the same registers with the same lane masks; defined out of line.
bool isEqual(const GCNRPTracker::LiveRegSet &S1,
             const GCNRPTracker::LiveRegSet &S2);

/// Printable rendering of \p RP. \p ST is optional and defaults to null.
Printable print(const GCNRegPressure &RP, const GCNSubtarget *ST = nullptr);

/// Printable rendering of a live register set. \p MRI is passed alongside,
/// presumably to describe each register.
Printable print(const GCNRPTracker::LiveRegSet &LiveRegs,
                const MachineRegisterInfo &MRI);

/// Printable report of a mismatch between two live register sets.
/// NOTE(review): judging by the names, \p LISLR is the set computed from LIS
/// and \p TrackedL the one maintained by a tracker; confirm against the
/// out-of-line definition.
Printable reportMismatch(const GCNRPTracker::LiveRegSet &LISLR,
                         const GCNRPTracker::LiveRegSet &TrackedL,
                         const TargetRegisterInfo *TRI);
279
280 } // end namespace llvm
281
282 #endif // LLVM_LIB_TARGET_AMDGPU_GCNREGPRESSURE_H
283