1 //===- GCNRegPressure.cpp -------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 ///
9 /// \file
10 /// This file implements the GCNRegPressure class.
11 ///
12 //===----------------------------------------------------------------------===//
13
14 #include "GCNRegPressure.h"
15 #include "llvm/CodeGen/RegisterPressure.h"
16
17 using namespace llvm;
18
19 #define DEBUG_TYPE "machine-scheduler"
20
21 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
22 LLVM_DUMP_METHOD
printLivesAt(SlotIndex SI,const LiveIntervals & LIS,const MachineRegisterInfo & MRI)23 void llvm::printLivesAt(SlotIndex SI,
24 const LiveIntervals &LIS,
25 const MachineRegisterInfo &MRI) {
26 dbgs() << "Live regs at " << SI << ": "
27 << *LIS.getInstructionFromIndex(SI);
28 unsigned Num = 0;
29 for (unsigned I = 0, E = MRI.getNumVirtRegs(); I != E; ++I) {
30 const unsigned Reg = Register::index2VirtReg(I);
31 if (!LIS.hasInterval(Reg))
32 continue;
33 const auto &LI = LIS.getInterval(Reg);
34 if (LI.hasSubRanges()) {
35 bool firstTime = true;
36 for (const auto &S : LI.subranges()) {
37 if (!S.liveAt(SI)) continue;
38 if (firstTime) {
39 dbgs() << " " << printReg(Reg, MRI.getTargetRegisterInfo())
40 << '\n';
41 firstTime = false;
42 }
43 dbgs() << " " << S << '\n';
44 ++Num;
45 }
46 } else if (LI.liveAt(SI)) {
47 dbgs() << " " << LI << '\n';
48 ++Num;
49 }
50 }
51 if (!Num) dbgs() << " <none>\n";
52 }
53 #endif
54
isEqual(const GCNRPTracker::LiveRegSet & S1,const GCNRPTracker::LiveRegSet & S2)55 bool llvm::isEqual(const GCNRPTracker::LiveRegSet &S1,
56 const GCNRPTracker::LiveRegSet &S2) {
57 if (S1.size() != S2.size())
58 return false;
59
60 for (const auto &P : S1) {
61 auto I = S2.find(P.first);
62 if (I == S2.end() || I->second != P.second)
63 return false;
64 }
65 return true;
66 }
67
68
69 ///////////////////////////////////////////////////////////////////////////////
70 // GCNRegPressure
71
getRegKind(Register Reg,const MachineRegisterInfo & MRI)72 unsigned GCNRegPressure::getRegKind(Register Reg,
73 const MachineRegisterInfo &MRI) {
74 assert(Reg.isVirtual());
75 const auto RC = MRI.getRegClass(Reg);
76 auto STI = static_cast<const SIRegisterInfo*>(MRI.getTargetRegisterInfo());
77 return STI->isSGPRClass(RC) ?
78 (STI->getRegSizeInBits(*RC) == 32 ? SGPR32 : SGPR_TUPLE) :
79 STI->hasAGPRs(RC) ?
80 (STI->getRegSizeInBits(*RC) == 32 ? AGPR32 : AGPR_TUPLE) :
81 (STI->getRegSizeInBits(*RC) == 32 ? VGPR32 : VGPR_TUPLE);
82 }
83
inc(unsigned Reg,LaneBitmask PrevMask,LaneBitmask NewMask,const MachineRegisterInfo & MRI)84 void GCNRegPressure::inc(unsigned Reg,
85 LaneBitmask PrevMask,
86 LaneBitmask NewMask,
87 const MachineRegisterInfo &MRI) {
88 if (SIRegisterInfo::getNumCoveredRegs(NewMask) ==
89 SIRegisterInfo::getNumCoveredRegs(PrevMask))
90 return;
91
92 int Sign = 1;
93 if (NewMask < PrevMask) {
94 std::swap(NewMask, PrevMask);
95 Sign = -1;
96 }
97
98 switch (auto Kind = getRegKind(Reg, MRI)) {
99 case SGPR32:
100 case VGPR32:
101 case AGPR32:
102 Value[Kind] += Sign;
103 break;
104
105 case SGPR_TUPLE:
106 case VGPR_TUPLE:
107 case AGPR_TUPLE:
108 assert(PrevMask < NewMask);
109
110 Value[Kind == SGPR_TUPLE ? SGPR32 : Kind == AGPR_TUPLE ? AGPR32 : VGPR32] +=
111 Sign * SIRegisterInfo::getNumCoveredRegs(~PrevMask & NewMask);
112
113 if (PrevMask.none()) {
114 assert(NewMask.any());
115 Value[Kind] += Sign * MRI.getPressureSets(Reg).getWeight();
116 }
117 break;
118
119 default: llvm_unreachable("Unknown register kind");
120 }
121 }
122
less(const GCNSubtarget & ST,const GCNRegPressure & O,unsigned MaxOccupancy) const123 bool GCNRegPressure::less(const GCNSubtarget &ST,
124 const GCNRegPressure& O,
125 unsigned MaxOccupancy) const {
126 const auto SGPROcc = std::min(MaxOccupancy,
127 ST.getOccupancyWithNumSGPRs(getSGPRNum()));
128 const auto VGPROcc = std::min(MaxOccupancy,
129 ST.getOccupancyWithNumVGPRs(getVGPRNum()));
130 const auto OtherSGPROcc = std::min(MaxOccupancy,
131 ST.getOccupancyWithNumSGPRs(O.getSGPRNum()));
132 const auto OtherVGPROcc = std::min(MaxOccupancy,
133 ST.getOccupancyWithNumVGPRs(O.getVGPRNum()));
134
135 const auto Occ = std::min(SGPROcc, VGPROcc);
136 const auto OtherOcc = std::min(OtherSGPROcc, OtherVGPROcc);
137 if (Occ != OtherOcc)
138 return Occ > OtherOcc;
139
140 bool SGPRImportant = SGPROcc < VGPROcc;
141 const bool OtherSGPRImportant = OtherSGPROcc < OtherVGPROcc;
142
143 // if both pressures disagree on what is more important compare vgprs
144 if (SGPRImportant != OtherSGPRImportant) {
145 SGPRImportant = false;
146 }
147
148 // compare large regs pressure
149 bool SGPRFirst = SGPRImportant;
150 for (int I = 2; I > 0; --I, SGPRFirst = !SGPRFirst) {
151 if (SGPRFirst) {
152 auto SW = getSGPRTuplesWeight();
153 auto OtherSW = O.getSGPRTuplesWeight();
154 if (SW != OtherSW)
155 return SW < OtherSW;
156 } else {
157 auto VW = getVGPRTuplesWeight();
158 auto OtherVW = O.getVGPRTuplesWeight();
159 if (VW != OtherVW)
160 return VW < OtherVW;
161 }
162 }
163 return SGPRImportant ? (getSGPRNum() < O.getSGPRNum()):
164 (getVGPRNum() < O.getVGPRNum());
165 }
166
167 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
168 LLVM_DUMP_METHOD
print(raw_ostream & OS,const GCNSubtarget * ST) const169 void GCNRegPressure::print(raw_ostream &OS, const GCNSubtarget *ST) const {
170 OS << "VGPRs: " << Value[VGPR32] << ' ';
171 OS << "AGPRs: " << Value[AGPR32];
172 if (ST) OS << "(O" << ST->getOccupancyWithNumVGPRs(getVGPRNum()) << ')';
173 OS << ", SGPRs: " << getSGPRNum();
174 if (ST) OS << "(O" << ST->getOccupancyWithNumSGPRs(getSGPRNum()) << ')';
175 OS << ", LVGPR WT: " << getVGPRTuplesWeight()
176 << ", LSGPR WT: " << getSGPRTuplesWeight();
177 if (ST) OS << " -> Occ: " << getOccupancy(*ST);
178 OS << '\n';
179 }
180 #endif
181
getDefRegMask(const MachineOperand & MO,const MachineRegisterInfo & MRI)182 static LaneBitmask getDefRegMask(const MachineOperand &MO,
183 const MachineRegisterInfo &MRI) {
184 assert(MO.isDef() && MO.isReg() && MO.getReg().isVirtual());
185
186 // We don't rely on read-undef flag because in case of tentative schedule
187 // tracking it isn't set correctly yet. This works correctly however since
188 // use mask has been tracked before using LIS.
189 return MO.getSubReg() == 0 ?
190 MRI.getMaxLaneMaskForVReg(MO.getReg()) :
191 MRI.getTargetRegisterInfo()->getSubRegIndexLaneMask(MO.getSubReg());
192 }
193
getUsedRegMask(const MachineOperand & MO,const MachineRegisterInfo & MRI,const LiveIntervals & LIS)194 static LaneBitmask getUsedRegMask(const MachineOperand &MO,
195 const MachineRegisterInfo &MRI,
196 const LiveIntervals &LIS) {
197 assert(MO.isUse() && MO.isReg() && MO.getReg().isVirtual());
198
199 if (auto SubReg = MO.getSubReg())
200 return MRI.getTargetRegisterInfo()->getSubRegIndexLaneMask(SubReg);
201
202 auto MaxMask = MRI.getMaxLaneMaskForVReg(MO.getReg());
203 if (SIRegisterInfo::getNumCoveredRegs(MaxMask) > 1) // cannot have subregs
204 return MaxMask;
205
206 // For a tentative schedule LIS isn't updated yet but livemask should remain
207 // the same on any schedule. Subreg defs can be reordered but they all must
208 // dominate uses anyway.
209 auto SI = LIS.getInstructionIndex(*MO.getParent()).getBaseIndex();
210 return getLiveLaneMask(MO.getReg(), SI, LIS, MRI);
211 }
212
213 static SmallVector<RegisterMaskPair, 8>
collectVirtualRegUses(const MachineInstr & MI,const LiveIntervals & LIS,const MachineRegisterInfo & MRI)214 collectVirtualRegUses(const MachineInstr &MI, const LiveIntervals &LIS,
215 const MachineRegisterInfo &MRI) {
216 SmallVector<RegisterMaskPair, 8> Res;
217 for (const auto &MO : MI.operands()) {
218 if (!MO.isReg() || !MO.getReg().isVirtual())
219 continue;
220 if (!MO.isUse() || !MO.readsReg())
221 continue;
222
223 auto const UsedMask = getUsedRegMask(MO, MRI, LIS);
224
225 auto Reg = MO.getReg();
226 auto I = llvm::find_if(
227 Res, [Reg](const RegisterMaskPair &RM) { return RM.RegUnit == Reg; });
228 if (I != Res.end())
229 I->LaneMask |= UsedMask;
230 else
231 Res.push_back(RegisterMaskPair(Reg, UsedMask));
232 }
233 return Res;
234 }
235
236 ///////////////////////////////////////////////////////////////////////////////
237 // GCNRPTracker
238
getLiveLaneMask(unsigned Reg,SlotIndex SI,const LiveIntervals & LIS,const MachineRegisterInfo & MRI)239 LaneBitmask llvm::getLiveLaneMask(unsigned Reg,
240 SlotIndex SI,
241 const LiveIntervals &LIS,
242 const MachineRegisterInfo &MRI) {
243 LaneBitmask LiveMask;
244 const auto &LI = LIS.getInterval(Reg);
245 if (LI.hasSubRanges()) {
246 for (const auto &S : LI.subranges())
247 if (S.liveAt(SI)) {
248 LiveMask |= S.LaneMask;
249 assert(LiveMask < MRI.getMaxLaneMaskForVReg(Reg) ||
250 LiveMask == MRI.getMaxLaneMaskForVReg(Reg));
251 }
252 } else if (LI.liveAt(SI)) {
253 LiveMask = MRI.getMaxLaneMaskForVReg(Reg);
254 }
255 return LiveMask;
256 }
257
getLiveRegs(SlotIndex SI,const LiveIntervals & LIS,const MachineRegisterInfo & MRI)258 GCNRPTracker::LiveRegSet llvm::getLiveRegs(SlotIndex SI,
259 const LiveIntervals &LIS,
260 const MachineRegisterInfo &MRI) {
261 GCNRPTracker::LiveRegSet LiveRegs;
262 for (unsigned I = 0, E = MRI.getNumVirtRegs(); I != E; ++I) {
263 auto Reg = Register::index2VirtReg(I);
264 if (!LIS.hasInterval(Reg))
265 continue;
266 auto LiveMask = getLiveLaneMask(Reg, SI, LIS, MRI);
267 if (LiveMask.any())
268 LiveRegs[Reg] = LiveMask;
269 }
270 return LiveRegs;
271 }
272
reset(const MachineInstr & MI,const LiveRegSet * LiveRegsCopy,bool After)273 void GCNRPTracker::reset(const MachineInstr &MI,
274 const LiveRegSet *LiveRegsCopy,
275 bool After) {
276 const MachineFunction &MF = *MI.getMF();
277 MRI = &MF.getRegInfo();
278 if (LiveRegsCopy) {
279 if (&LiveRegs != LiveRegsCopy)
280 LiveRegs = *LiveRegsCopy;
281 } else {
282 LiveRegs = After ? getLiveRegsAfter(MI, LIS)
283 : getLiveRegsBefore(MI, LIS);
284 }
285
286 MaxPressure = CurPressure = getRegPressure(*MRI, LiveRegs);
287 }
288
reset(const MachineInstr & MI,const LiveRegSet * LiveRegsCopy)289 void GCNUpwardRPTracker::reset(const MachineInstr &MI,
290 const LiveRegSet *LiveRegsCopy) {
291 GCNRPTracker::reset(MI, LiveRegsCopy, true);
292 }
293
recede(const MachineInstr & MI)294 void GCNUpwardRPTracker::recede(const MachineInstr &MI) {
295 assert(MRI && "call reset first");
296
297 LastTrackedMI = &MI;
298
299 if (MI.isDebugInstr())
300 return;
301
302 auto const RegUses = collectVirtualRegUses(MI, LIS, *MRI);
303
304 // calc pressure at the MI (defs + uses)
305 auto AtMIPressure = CurPressure;
306 for (const auto &U : RegUses) {
307 auto LiveMask = LiveRegs[U.RegUnit];
308 AtMIPressure.inc(U.RegUnit, LiveMask, LiveMask | U.LaneMask, *MRI);
309 }
310 // update max pressure
311 MaxPressure = max(AtMIPressure, MaxPressure);
312
313 for (const auto &MO : MI.operands()) {
314 if (!MO.isReg() || !MO.isDef() || !MO.getReg().isVirtual() || MO.isDead())
315 continue;
316
317 auto Reg = MO.getReg();
318 auto I = LiveRegs.find(Reg);
319 if (I == LiveRegs.end())
320 continue;
321 auto &LiveMask = I->second;
322 auto PrevMask = LiveMask;
323 LiveMask &= ~getDefRegMask(MO, *MRI);
324 CurPressure.inc(Reg, PrevMask, LiveMask, *MRI);
325 if (LiveMask.none())
326 LiveRegs.erase(I);
327 }
328 for (const auto &U : RegUses) {
329 auto &LiveMask = LiveRegs[U.RegUnit];
330 auto PrevMask = LiveMask;
331 LiveMask |= U.LaneMask;
332 CurPressure.inc(U.RegUnit, PrevMask, LiveMask, *MRI);
333 }
334 assert(CurPressure == getRegPressure(*MRI, LiveRegs));
335 }
336
reset(const MachineInstr & MI,const LiveRegSet * LiveRegsCopy)337 bool GCNDownwardRPTracker::reset(const MachineInstr &MI,
338 const LiveRegSet *LiveRegsCopy) {
339 MRI = &MI.getParent()->getParent()->getRegInfo();
340 LastTrackedMI = nullptr;
341 MBBEnd = MI.getParent()->end();
342 NextMI = &MI;
343 NextMI = skipDebugInstructionsForward(NextMI, MBBEnd);
344 if (NextMI == MBBEnd)
345 return false;
346 GCNRPTracker::reset(*NextMI, LiveRegsCopy, false);
347 return true;
348 }
349
advanceBeforeNext()350 bool GCNDownwardRPTracker::advanceBeforeNext() {
351 assert(MRI && "call reset first");
352
353 NextMI = skipDebugInstructionsForward(NextMI, MBBEnd);
354 if (NextMI == MBBEnd)
355 return false;
356
357 SlotIndex SI = LIS.getInstructionIndex(*NextMI).getBaseIndex();
358 assert(SI.isValid());
359
360 // Remove dead registers or mask bits.
361 for (auto &It : LiveRegs) {
362 const LiveInterval &LI = LIS.getInterval(It.first);
363 if (LI.hasSubRanges()) {
364 for (const auto &S : LI.subranges()) {
365 if (!S.liveAt(SI)) {
366 auto PrevMask = It.second;
367 It.second &= ~S.LaneMask;
368 CurPressure.inc(It.first, PrevMask, It.second, *MRI);
369 }
370 }
371 } else if (!LI.liveAt(SI)) {
372 auto PrevMask = It.second;
373 It.second = LaneBitmask::getNone();
374 CurPressure.inc(It.first, PrevMask, It.second, *MRI);
375 }
376 if (It.second.none())
377 LiveRegs.erase(It.first);
378 }
379
380 MaxPressure = max(MaxPressure, CurPressure);
381
382 return true;
383 }
384
advanceToNext()385 void GCNDownwardRPTracker::advanceToNext() {
386 LastTrackedMI = &*NextMI++;
387
388 // Add new registers or mask bits.
389 for (const auto &MO : LastTrackedMI->operands()) {
390 if (!MO.isReg() || !MO.isDef())
391 continue;
392 Register Reg = MO.getReg();
393 if (!Reg.isVirtual())
394 continue;
395 auto &LiveMask = LiveRegs[Reg];
396 auto PrevMask = LiveMask;
397 LiveMask |= getDefRegMask(MO, *MRI);
398 CurPressure.inc(Reg, PrevMask, LiveMask, *MRI);
399 }
400
401 MaxPressure = max(MaxPressure, CurPressure);
402 }
403
advance()404 bool GCNDownwardRPTracker::advance() {
405 // If we have just called reset live set is actual.
406 if ((NextMI == MBBEnd) || (LastTrackedMI && !advanceBeforeNext()))
407 return false;
408 advanceToNext();
409 return true;
410 }
411
advance(MachineBasicBlock::const_iterator End)412 bool GCNDownwardRPTracker::advance(MachineBasicBlock::const_iterator End) {
413 while (NextMI != End)
414 if (!advance()) return false;
415 return true;
416 }
417
advance(MachineBasicBlock::const_iterator Begin,MachineBasicBlock::const_iterator End,const LiveRegSet * LiveRegsCopy)418 bool GCNDownwardRPTracker::advance(MachineBasicBlock::const_iterator Begin,
419 MachineBasicBlock::const_iterator End,
420 const LiveRegSet *LiveRegsCopy) {
421 reset(*Begin, LiveRegsCopy);
422 return advance(End);
423 }
424
425 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
426 LLVM_DUMP_METHOD
reportMismatch(const GCNRPTracker::LiveRegSet & LISLR,const GCNRPTracker::LiveRegSet & TrackedLR,const TargetRegisterInfo * TRI)427 static void reportMismatch(const GCNRPTracker::LiveRegSet &LISLR,
428 const GCNRPTracker::LiveRegSet &TrackedLR,
429 const TargetRegisterInfo *TRI) {
430 for (auto const &P : TrackedLR) {
431 auto I = LISLR.find(P.first);
432 if (I == LISLR.end()) {
433 dbgs() << " " << printReg(P.first, TRI)
434 << ":L" << PrintLaneMask(P.second)
435 << " isn't found in LIS reported set\n";
436 }
437 else if (I->second != P.second) {
438 dbgs() << " " << printReg(P.first, TRI)
439 << " masks doesn't match: LIS reported "
440 << PrintLaneMask(I->second)
441 << ", tracked "
442 << PrintLaneMask(P.second)
443 << '\n';
444 }
445 }
446 for (auto const &P : LISLR) {
447 auto I = TrackedLR.find(P.first);
448 if (I == TrackedLR.end()) {
449 dbgs() << " " << printReg(P.first, TRI)
450 << ":L" << PrintLaneMask(P.second)
451 << " isn't found in tracked set\n";
452 }
453 }
454 }
455
isValid() const456 bool GCNUpwardRPTracker::isValid() const {
457 const auto &SI = LIS.getInstructionIndex(*LastTrackedMI).getBaseIndex();
458 const auto LISLR = llvm::getLiveRegs(SI, LIS, *MRI);
459 const auto &TrackedLR = LiveRegs;
460
461 if (!isEqual(LISLR, TrackedLR)) {
462 dbgs() << "\nGCNUpwardRPTracker error: Tracked and"
463 " LIS reported livesets mismatch:\n";
464 printLivesAt(SI, LIS, *MRI);
465 reportMismatch(LISLR, TrackedLR, MRI->getTargetRegisterInfo());
466 return false;
467 }
468
469 auto LISPressure = getRegPressure(*MRI, LISLR);
470 if (LISPressure != CurPressure) {
471 dbgs() << "GCNUpwardRPTracker error: Pressure sets different\nTracked: ";
472 CurPressure.print(dbgs());
473 dbgs() << "LIS rpt: ";
474 LISPressure.print(dbgs());
475 return false;
476 }
477 return true;
478 }
479
printLiveRegs(raw_ostream & OS,const LiveRegSet & LiveRegs,const MachineRegisterInfo & MRI)480 void GCNRPTracker::printLiveRegs(raw_ostream &OS, const LiveRegSet& LiveRegs,
481 const MachineRegisterInfo &MRI) {
482 const TargetRegisterInfo *TRI = MRI.getTargetRegisterInfo();
483 for (unsigned I = 0, E = MRI.getNumVirtRegs(); I != E; ++I) {
484 unsigned Reg = Register::index2VirtReg(I);
485 auto It = LiveRegs.find(Reg);
486 if (It != LiveRegs.end() && It->second.any())
487 OS << ' ' << printVRegOrUnit(Reg, TRI) << ':'
488 << PrintLaneMask(It->second);
489 }
490 OS << '\n';
491 }
492 #endif
493