1 //===- GCNRegPressure.cpp -------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 ///
9 /// \file
10 /// This file implements the GCNRegPressure class.
11 ///
12 //===----------------------------------------------------------------------===//
13
14 #include "GCNRegPressure.h"
15 #include "llvm/CodeGen/RegisterPressure.h"
16
17 using namespace llvm;
18
19 #define DEBUG_TYPE "machine-scheduler"
20
isEqual(const GCNRPTracker::LiveRegSet & S1,const GCNRPTracker::LiveRegSet & S2)21 bool llvm::isEqual(const GCNRPTracker::LiveRegSet &S1,
22 const GCNRPTracker::LiveRegSet &S2) {
23 if (S1.size() != S2.size())
24 return false;
25
26 for (const auto &P : S1) {
27 auto I = S2.find(P.first);
28 if (I == S2.end() || I->second != P.second)
29 return false;
30 }
31 return true;
32 }
33
34
35 ///////////////////////////////////////////////////////////////////////////////
36 // GCNRegPressure
37
getRegKind(Register Reg,const MachineRegisterInfo & MRI)38 unsigned GCNRegPressure::getRegKind(Register Reg,
39 const MachineRegisterInfo &MRI) {
40 assert(Reg.isVirtual());
41 const auto RC = MRI.getRegClass(Reg);
42 auto STI = static_cast<const SIRegisterInfo*>(MRI.getTargetRegisterInfo());
43 return STI->isSGPRClass(RC)
44 ? (STI->getRegSizeInBits(*RC) == 32 ? SGPR32 : SGPR_TUPLE)
45 : STI->isAGPRClass(RC)
46 ? (STI->getRegSizeInBits(*RC) == 32 ? AGPR32 : AGPR_TUPLE)
47 : (STI->getRegSizeInBits(*RC) == 32 ? VGPR32 : VGPR_TUPLE);
48 }
49
inc(unsigned Reg,LaneBitmask PrevMask,LaneBitmask NewMask,const MachineRegisterInfo & MRI)50 void GCNRegPressure::inc(unsigned Reg,
51 LaneBitmask PrevMask,
52 LaneBitmask NewMask,
53 const MachineRegisterInfo &MRI) {
54 if (SIRegisterInfo::getNumCoveredRegs(NewMask) ==
55 SIRegisterInfo::getNumCoveredRegs(PrevMask))
56 return;
57
58 int Sign = 1;
59 if (NewMask < PrevMask) {
60 std::swap(NewMask, PrevMask);
61 Sign = -1;
62 }
63
64 switch (auto Kind = getRegKind(Reg, MRI)) {
65 case SGPR32:
66 case VGPR32:
67 case AGPR32:
68 Value[Kind] += Sign;
69 break;
70
71 case SGPR_TUPLE:
72 case VGPR_TUPLE:
73 case AGPR_TUPLE:
74 assert(PrevMask < NewMask);
75
76 Value[Kind == SGPR_TUPLE ? SGPR32 : Kind == AGPR_TUPLE ? AGPR32 : VGPR32] +=
77 Sign * SIRegisterInfo::getNumCoveredRegs(~PrevMask & NewMask);
78
79 if (PrevMask.none()) {
80 assert(NewMask.any());
81 Value[Kind] += Sign * MRI.getPressureSets(Reg).getWeight();
82 }
83 break;
84
85 default: llvm_unreachable("Unknown register kind");
86 }
87 }
88
less(const GCNSubtarget & ST,const GCNRegPressure & O,unsigned MaxOccupancy) const89 bool GCNRegPressure::less(const GCNSubtarget &ST,
90 const GCNRegPressure& O,
91 unsigned MaxOccupancy) const {
92 const auto SGPROcc = std::min(MaxOccupancy,
93 ST.getOccupancyWithNumSGPRs(getSGPRNum()));
94 const auto VGPROcc =
95 std::min(MaxOccupancy,
96 ST.getOccupancyWithNumVGPRs(getVGPRNum(ST.hasGFX90AInsts())));
97 const auto OtherSGPROcc = std::min(MaxOccupancy,
98 ST.getOccupancyWithNumSGPRs(O.getSGPRNum()));
99 const auto OtherVGPROcc =
100 std::min(MaxOccupancy,
101 ST.getOccupancyWithNumVGPRs(O.getVGPRNum(ST.hasGFX90AInsts())));
102
103 const auto Occ = std::min(SGPROcc, VGPROcc);
104 const auto OtherOcc = std::min(OtherSGPROcc, OtherVGPROcc);
105 if (Occ != OtherOcc)
106 return Occ > OtherOcc;
107
108 bool SGPRImportant = SGPROcc < VGPROcc;
109 const bool OtherSGPRImportant = OtherSGPROcc < OtherVGPROcc;
110
111 // if both pressures disagree on what is more important compare vgprs
112 if (SGPRImportant != OtherSGPRImportant) {
113 SGPRImportant = false;
114 }
115
116 // compare large regs pressure
117 bool SGPRFirst = SGPRImportant;
118 for (int I = 2; I > 0; --I, SGPRFirst = !SGPRFirst) {
119 if (SGPRFirst) {
120 auto SW = getSGPRTuplesWeight();
121 auto OtherSW = O.getSGPRTuplesWeight();
122 if (SW != OtherSW)
123 return SW < OtherSW;
124 } else {
125 auto VW = getVGPRTuplesWeight();
126 auto OtherVW = O.getVGPRTuplesWeight();
127 if (VW != OtherVW)
128 return VW < OtherVW;
129 }
130 }
131 return SGPRImportant ? (getSGPRNum() < O.getSGPRNum()):
132 (getVGPRNum(ST.hasGFX90AInsts()) <
133 O.getVGPRNum(ST.hasGFX90AInsts()));
134 }
135
136 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
137 LLVM_DUMP_METHOD
print(const GCNRegPressure & RP,const GCNSubtarget * ST)138 Printable llvm::print(const GCNRegPressure &RP, const GCNSubtarget *ST) {
139 return Printable([&RP, ST](raw_ostream &OS) {
140 OS << "VGPRs: " << RP.Value[GCNRegPressure::VGPR32] << ' '
141 << "AGPRs: " << RP.getAGPRNum();
142 if (ST)
143 OS << "(O"
144 << ST->getOccupancyWithNumVGPRs(RP.getVGPRNum(ST->hasGFX90AInsts()))
145 << ')';
146 OS << ", SGPRs: " << RP.getSGPRNum();
147 if (ST)
148 OS << "(O" << ST->getOccupancyWithNumSGPRs(RP.getSGPRNum()) << ')';
149 OS << ", LVGPR WT: " << RP.getVGPRTuplesWeight()
150 << ", LSGPR WT: " << RP.getSGPRTuplesWeight();
151 if (ST)
152 OS << " -> Occ: " << RP.getOccupancy(*ST);
153 OS << '\n';
154 });
155 }
156 #endif
157
getDefRegMask(const MachineOperand & MO,const MachineRegisterInfo & MRI)158 static LaneBitmask getDefRegMask(const MachineOperand &MO,
159 const MachineRegisterInfo &MRI) {
160 assert(MO.isDef() && MO.isReg() && MO.getReg().isVirtual());
161
162 // We don't rely on read-undef flag because in case of tentative schedule
163 // tracking it isn't set correctly yet. This works correctly however since
164 // use mask has been tracked before using LIS.
165 return MO.getSubReg() == 0 ?
166 MRI.getMaxLaneMaskForVReg(MO.getReg()) :
167 MRI.getTargetRegisterInfo()->getSubRegIndexLaneMask(MO.getSubReg());
168 }
169
getUsedRegMask(const MachineOperand & MO,const MachineRegisterInfo & MRI,const LiveIntervals & LIS)170 static LaneBitmask getUsedRegMask(const MachineOperand &MO,
171 const MachineRegisterInfo &MRI,
172 const LiveIntervals &LIS) {
173 assert(MO.isUse() && MO.isReg() && MO.getReg().isVirtual());
174
175 if (auto SubReg = MO.getSubReg())
176 return MRI.getTargetRegisterInfo()->getSubRegIndexLaneMask(SubReg);
177
178 auto MaxMask = MRI.getMaxLaneMaskForVReg(MO.getReg());
179 if (SIRegisterInfo::getNumCoveredRegs(MaxMask) > 1) // cannot have subregs
180 return MaxMask;
181
182 // For a tentative schedule LIS isn't updated yet but livemask should remain
183 // the same on any schedule. Subreg defs can be reordered but they all must
184 // dominate uses anyway.
185 auto SI = LIS.getInstructionIndex(*MO.getParent()).getBaseIndex();
186 return getLiveLaneMask(MO.getReg(), SI, LIS, MRI);
187 }
188
189 static SmallVector<RegisterMaskPair, 8>
collectVirtualRegUses(const MachineInstr & MI,const LiveIntervals & LIS,const MachineRegisterInfo & MRI)190 collectVirtualRegUses(const MachineInstr &MI, const LiveIntervals &LIS,
191 const MachineRegisterInfo &MRI) {
192 SmallVector<RegisterMaskPair, 8> Res;
193 for (const auto &MO : MI.operands()) {
194 if (!MO.isReg() || !MO.getReg().isVirtual())
195 continue;
196 if (!MO.isUse() || !MO.readsReg())
197 continue;
198
199 auto const UsedMask = getUsedRegMask(MO, MRI, LIS);
200
201 auto Reg = MO.getReg();
202 auto I = llvm::find_if(
203 Res, [Reg](const RegisterMaskPair &RM) { return RM.RegUnit == Reg; });
204 if (I != Res.end())
205 I->LaneMask |= UsedMask;
206 else
207 Res.push_back(RegisterMaskPair(Reg, UsedMask));
208 }
209 return Res;
210 }
211
212 ///////////////////////////////////////////////////////////////////////////////
213 // GCNRPTracker
214
getLiveLaneMask(unsigned Reg,SlotIndex SI,const LiveIntervals & LIS,const MachineRegisterInfo & MRI)215 LaneBitmask llvm::getLiveLaneMask(unsigned Reg,
216 SlotIndex SI,
217 const LiveIntervals &LIS,
218 const MachineRegisterInfo &MRI) {
219 LaneBitmask LiveMask;
220 const auto &LI = LIS.getInterval(Reg);
221 if (LI.hasSubRanges()) {
222 for (const auto &S : LI.subranges())
223 if (S.liveAt(SI)) {
224 LiveMask |= S.LaneMask;
225 assert(LiveMask < MRI.getMaxLaneMaskForVReg(Reg) ||
226 LiveMask == MRI.getMaxLaneMaskForVReg(Reg));
227 }
228 } else if (LI.liveAt(SI)) {
229 LiveMask = MRI.getMaxLaneMaskForVReg(Reg);
230 }
231 return LiveMask;
232 }
233
getLiveRegs(SlotIndex SI,const LiveIntervals & LIS,const MachineRegisterInfo & MRI)234 GCNRPTracker::LiveRegSet llvm::getLiveRegs(SlotIndex SI,
235 const LiveIntervals &LIS,
236 const MachineRegisterInfo &MRI) {
237 GCNRPTracker::LiveRegSet LiveRegs;
238 for (unsigned I = 0, E = MRI.getNumVirtRegs(); I != E; ++I) {
239 auto Reg = Register::index2VirtReg(I);
240 if (!LIS.hasInterval(Reg))
241 continue;
242 auto LiveMask = getLiveLaneMask(Reg, SI, LIS, MRI);
243 if (LiveMask.any())
244 LiveRegs[Reg] = LiveMask;
245 }
246 return LiveRegs;
247 }
248
reset(const MachineInstr & MI,const LiveRegSet * LiveRegsCopy,bool After)249 void GCNRPTracker::reset(const MachineInstr &MI,
250 const LiveRegSet *LiveRegsCopy,
251 bool After) {
252 const MachineFunction &MF = *MI.getMF();
253 MRI = &MF.getRegInfo();
254 if (LiveRegsCopy) {
255 if (&LiveRegs != LiveRegsCopy)
256 LiveRegs = *LiveRegsCopy;
257 } else {
258 LiveRegs = After ? getLiveRegsAfter(MI, LIS)
259 : getLiveRegsBefore(MI, LIS);
260 }
261
262 MaxPressure = CurPressure = getRegPressure(*MRI, LiveRegs);
263 }
264
reset(const MachineInstr & MI,const LiveRegSet * LiveRegsCopy)265 void GCNUpwardRPTracker::reset(const MachineInstr &MI,
266 const LiveRegSet *LiveRegsCopy) {
267 GCNRPTracker::reset(MI, LiveRegsCopy, true);
268 }
269
recede(const MachineInstr & MI)270 void GCNUpwardRPTracker::recede(const MachineInstr &MI) {
271 assert(MRI && "call reset first");
272
273 LastTrackedMI = &MI;
274
275 if (MI.isDebugInstr())
276 return;
277
278 auto const RegUses = collectVirtualRegUses(MI, LIS, *MRI);
279
280 // calc pressure at the MI (defs + uses)
281 auto AtMIPressure = CurPressure;
282 for (const auto &U : RegUses) {
283 auto LiveMask = LiveRegs[U.RegUnit];
284 AtMIPressure.inc(U.RegUnit, LiveMask, LiveMask | U.LaneMask, *MRI);
285 }
286 // update max pressure
287 MaxPressure = max(AtMIPressure, MaxPressure);
288
289 for (const auto &MO : MI.operands()) {
290 if (!MO.isReg() || !MO.isDef() || !MO.getReg().isVirtual() || MO.isDead())
291 continue;
292
293 auto Reg = MO.getReg();
294 auto I = LiveRegs.find(Reg);
295 if (I == LiveRegs.end())
296 continue;
297 auto &LiveMask = I->second;
298 auto PrevMask = LiveMask;
299 LiveMask &= ~getDefRegMask(MO, *MRI);
300 CurPressure.inc(Reg, PrevMask, LiveMask, *MRI);
301 if (LiveMask.none())
302 LiveRegs.erase(I);
303 }
304 for (const auto &U : RegUses) {
305 auto &LiveMask = LiveRegs[U.RegUnit];
306 auto PrevMask = LiveMask;
307 LiveMask |= U.LaneMask;
308 CurPressure.inc(U.RegUnit, PrevMask, LiveMask, *MRI);
309 }
310 assert(CurPressure == getRegPressure(*MRI, LiveRegs));
311 }
312
reset(const MachineInstr & MI,const LiveRegSet * LiveRegsCopy)313 bool GCNDownwardRPTracker::reset(const MachineInstr &MI,
314 const LiveRegSet *LiveRegsCopy) {
315 MRI = &MI.getParent()->getParent()->getRegInfo();
316 LastTrackedMI = nullptr;
317 MBBEnd = MI.getParent()->end();
318 NextMI = &MI;
319 NextMI = skipDebugInstructionsForward(NextMI, MBBEnd);
320 if (NextMI == MBBEnd)
321 return false;
322 GCNRPTracker::reset(*NextMI, LiveRegsCopy, false);
323 return true;
324 }
325
advanceBeforeNext()326 bool GCNDownwardRPTracker::advanceBeforeNext() {
327 assert(MRI && "call reset first");
328 if (!LastTrackedMI)
329 return NextMI == MBBEnd;
330
331 assert(NextMI == MBBEnd || !NextMI->isDebugInstr());
332
333 SlotIndex SI = NextMI == MBBEnd
334 ? LIS.getInstructionIndex(*LastTrackedMI).getDeadSlot()
335 : LIS.getInstructionIndex(*NextMI).getBaseIndex();
336 assert(SI.isValid());
337
338 // Remove dead registers or mask bits.
339 for (auto &It : LiveRegs) {
340 const LiveInterval &LI = LIS.getInterval(It.first);
341 if (LI.hasSubRanges()) {
342 for (const auto &S : LI.subranges()) {
343 if (!S.liveAt(SI)) {
344 auto PrevMask = It.second;
345 It.second &= ~S.LaneMask;
346 CurPressure.inc(It.first, PrevMask, It.second, *MRI);
347 }
348 }
349 } else if (!LI.liveAt(SI)) {
350 auto PrevMask = It.second;
351 It.second = LaneBitmask::getNone();
352 CurPressure.inc(It.first, PrevMask, It.second, *MRI);
353 }
354 if (It.second.none())
355 LiveRegs.erase(It.first);
356 }
357
358 MaxPressure = max(MaxPressure, CurPressure);
359
360 LastTrackedMI = nullptr;
361
362 return NextMI == MBBEnd;
363 }
364
advanceToNext()365 void GCNDownwardRPTracker::advanceToNext() {
366 LastTrackedMI = &*NextMI++;
367 NextMI = skipDebugInstructionsForward(NextMI, MBBEnd);
368
369 // Add new registers or mask bits.
370 for (const auto &MO : LastTrackedMI->operands()) {
371 if (!MO.isReg() || !MO.isDef())
372 continue;
373 Register Reg = MO.getReg();
374 if (!Reg.isVirtual())
375 continue;
376 auto &LiveMask = LiveRegs[Reg];
377 auto PrevMask = LiveMask;
378 LiveMask |= getDefRegMask(MO, *MRI);
379 CurPressure.inc(Reg, PrevMask, LiveMask, *MRI);
380 }
381
382 MaxPressure = max(MaxPressure, CurPressure);
383 }
384
advance()385 bool GCNDownwardRPTracker::advance() {
386 if (NextMI == MBBEnd)
387 return false;
388 advanceBeforeNext();
389 advanceToNext();
390 return true;
391 }
392
advance(MachineBasicBlock::const_iterator End)393 bool GCNDownwardRPTracker::advance(MachineBasicBlock::const_iterator End) {
394 while (NextMI != End)
395 if (!advance()) return false;
396 return true;
397 }
398
advance(MachineBasicBlock::const_iterator Begin,MachineBasicBlock::const_iterator End,const LiveRegSet * LiveRegsCopy)399 bool GCNDownwardRPTracker::advance(MachineBasicBlock::const_iterator Begin,
400 MachineBasicBlock::const_iterator End,
401 const LiveRegSet *LiveRegsCopy) {
402 reset(*Begin, LiveRegsCopy);
403 return advance(End);
404 }
405
406 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
407 LLVM_DUMP_METHOD
reportMismatch(const GCNRPTracker::LiveRegSet & LISLR,const GCNRPTracker::LiveRegSet & TrackedLR,const TargetRegisterInfo * TRI)408 Printable llvm::reportMismatch(const GCNRPTracker::LiveRegSet &LISLR,
409 const GCNRPTracker::LiveRegSet &TrackedLR,
410 const TargetRegisterInfo *TRI) {
411 return Printable([&LISLR, &TrackedLR, TRI](raw_ostream &OS) {
412 for (auto const &P : TrackedLR) {
413 auto I = LISLR.find(P.first);
414 if (I == LISLR.end()) {
415 OS << " " << printReg(P.first, TRI) << ":L" << PrintLaneMask(P.second)
416 << " isn't found in LIS reported set\n";
417 } else if (I->second != P.second) {
418 OS << " " << printReg(P.first, TRI)
419 << " masks doesn't match: LIS reported " << PrintLaneMask(I->second)
420 << ", tracked " << PrintLaneMask(P.second) << '\n';
421 }
422 }
423 for (auto const &P : LISLR) {
424 auto I = TrackedLR.find(P.first);
425 if (I == TrackedLR.end()) {
426 OS << " " << printReg(P.first, TRI) << ":L" << PrintLaneMask(P.second)
427 << " isn't found in tracked set\n";
428 }
429 }
430 });
431 }
432
isValid() const433 bool GCNUpwardRPTracker::isValid() const {
434 const auto &SI = LIS.getInstructionIndex(*LastTrackedMI).getBaseIndex();
435 const auto LISLR = llvm::getLiveRegs(SI, LIS, *MRI);
436 const auto &TrackedLR = LiveRegs;
437
438 if (!isEqual(LISLR, TrackedLR)) {
439 dbgs() << "\nGCNUpwardRPTracker error: Tracked and"
440 " LIS reported livesets mismatch:\n"
441 << print(LISLR, *MRI);
442 reportMismatch(LISLR, TrackedLR, MRI->getTargetRegisterInfo());
443 return false;
444 }
445
446 auto LISPressure = getRegPressure(*MRI, LISLR);
447 if (LISPressure != CurPressure) {
448 dbgs() << "GCNUpwardRPTracker error: Pressure sets different\nTracked: "
449 << print(CurPressure) << "LIS rpt: " << print(LISPressure);
450 return false;
451 }
452 return true;
453 }
454
455 LLVM_DUMP_METHOD
print(const GCNRPTracker::LiveRegSet & LiveRegs,const MachineRegisterInfo & MRI)456 Printable llvm::print(const GCNRPTracker::LiveRegSet &LiveRegs,
457 const MachineRegisterInfo &MRI) {
458 return Printable([&LiveRegs, &MRI](raw_ostream &OS) {
459 const TargetRegisterInfo *TRI = MRI.getTargetRegisterInfo();
460 for (unsigned I = 0, E = MRI.getNumVirtRegs(); I != E; ++I) {
461 Register Reg = Register::index2VirtReg(I);
462 auto It = LiveRegs.find(Reg);
463 if (It != LiveRegs.end() && It->second.any())
464 OS << ' ' << printVRegOrUnit(Reg, TRI) << ':'
465 << PrintLaneMask(It->second);
466 }
467 OS << '\n';
468 });
469 }
470
471 LLVM_DUMP_METHOD
dump() const472 void GCNRegPressure::dump() const { dbgs() << print(*this); }
473
474 #endif
475