//===- GCNRegPressure.cpp -------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
///
/// \file
/// This file implements the GCNRegPressure class.
///
//===----------------------------------------------------------------------===//
13
14 #include "GCNRegPressure.h"
15 #include "AMDGPUSubtarget.h"
16 #include "SIRegisterInfo.h"
17 #include "llvm/ADT/SmallVector.h"
18 #include "llvm/CodeGen/LiveInterval.h"
19 #include "llvm/CodeGen/LiveIntervals.h"
20 #include "llvm/CodeGen/MachineInstr.h"
21 #include "llvm/CodeGen/MachineOperand.h"
22 #include "llvm/CodeGen/MachineRegisterInfo.h"
23 #include "llvm/CodeGen/RegisterPressure.h"
24 #include "llvm/CodeGen/SlotIndexes.h"
25 #include "llvm/CodeGen/TargetRegisterInfo.h"
26 #include "llvm/Config/llvm-config.h"
27 #include "llvm/MC/LaneBitmask.h"
28 #include "llvm/Support/Compiler.h"
29 #include "llvm/Support/Debug.h"
30 #include "llvm/Support/ErrorHandling.h"
31 #include "llvm/Support/raw_ostream.h"
32 #include <algorithm>
33 #include <cassert>
34
35 using namespace llvm;
36
37 #define DEBUG_TYPE "machine-scheduler"
38
39 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
40 LLVM_DUMP_METHOD
printLivesAt(SlotIndex SI,const LiveIntervals & LIS,const MachineRegisterInfo & MRI)41 void llvm::printLivesAt(SlotIndex SI,
42 const LiveIntervals &LIS,
43 const MachineRegisterInfo &MRI) {
44 dbgs() << "Live regs at " << SI << ": "
45 << *LIS.getInstructionFromIndex(SI);
46 unsigned Num = 0;
47 for (unsigned I = 0, E = MRI.getNumVirtRegs(); I != E; ++I) {
48 const unsigned Reg = Register::index2VirtReg(I);
49 if (!LIS.hasInterval(Reg))
50 continue;
51 const auto &LI = LIS.getInterval(Reg);
52 if (LI.hasSubRanges()) {
53 bool firstTime = true;
54 for (const auto &S : LI.subranges()) {
55 if (!S.liveAt(SI)) continue;
56 if (firstTime) {
57 dbgs() << " " << printReg(Reg, MRI.getTargetRegisterInfo())
58 << '\n';
59 firstTime = false;
60 }
61 dbgs() << " " << S << '\n';
62 ++Num;
63 }
64 } else if (LI.liveAt(SI)) {
65 dbgs() << " " << LI << '\n';
66 ++Num;
67 }
68 }
69 if (!Num) dbgs() << " <none>\n";
70 }
71 #endif
72
isEqual(const GCNRPTracker::LiveRegSet & S1,const GCNRPTracker::LiveRegSet & S2)73 bool llvm::isEqual(const GCNRPTracker::LiveRegSet &S1,
74 const GCNRPTracker::LiveRegSet &S2) {
75 if (S1.size() != S2.size())
76 return false;
77
78 for (const auto &P : S1) {
79 auto I = S2.find(P.first);
80 if (I == S2.end() || I->second != P.second)
81 return false;
82 }
83 return true;
84 }
85
86
///////////////////////////////////////////////////////////////////////////////
// GCNRegPressure
89
getRegKind(unsigned Reg,const MachineRegisterInfo & MRI)90 unsigned GCNRegPressure::getRegKind(unsigned Reg,
91 const MachineRegisterInfo &MRI) {
92 assert(Register::isVirtualRegister(Reg));
93 const auto RC = MRI.getRegClass(Reg);
94 auto STI = static_cast<const SIRegisterInfo*>(MRI.getTargetRegisterInfo());
95 return STI->isSGPRClass(RC) ?
96 (STI->getRegSizeInBits(*RC) == 32 ? SGPR32 : SGPR_TUPLE) :
97 STI->hasAGPRs(RC) ?
98 (STI->getRegSizeInBits(*RC) == 32 ? AGPR32 : AGPR_TUPLE) :
99 (STI->getRegSizeInBits(*RC) == 32 ? VGPR32 : VGPR_TUPLE);
100 }
101
inc(unsigned Reg,LaneBitmask PrevMask,LaneBitmask NewMask,const MachineRegisterInfo & MRI)102 void GCNRegPressure::inc(unsigned Reg,
103 LaneBitmask PrevMask,
104 LaneBitmask NewMask,
105 const MachineRegisterInfo &MRI) {
106 if (SIRegisterInfo::getNumCoveredRegs(NewMask) ==
107 SIRegisterInfo::getNumCoveredRegs(PrevMask))
108 return;
109
110 int Sign = 1;
111 if (NewMask < PrevMask) {
112 std::swap(NewMask, PrevMask);
113 Sign = -1;
114 }
115
116 switch (auto Kind = getRegKind(Reg, MRI)) {
117 case SGPR32:
118 case VGPR32:
119 case AGPR32:
120 Value[Kind] += Sign;
121 break;
122
123 case SGPR_TUPLE:
124 case VGPR_TUPLE:
125 case AGPR_TUPLE:
126 assert(PrevMask < NewMask);
127
128 Value[Kind == SGPR_TUPLE ? SGPR32 : Kind == AGPR_TUPLE ? AGPR32 : VGPR32] +=
129 Sign * SIRegisterInfo::getNumCoveredRegs(~PrevMask & NewMask);
130
131 if (PrevMask.none()) {
132 assert(NewMask.any());
133 Value[Kind] += Sign * MRI.getPressureSets(Reg).getWeight();
134 }
135 break;
136
137 default: llvm_unreachable("Unknown register kind");
138 }
139 }
140
less(const GCNSubtarget & ST,const GCNRegPressure & O,unsigned MaxOccupancy) const141 bool GCNRegPressure::less(const GCNSubtarget &ST,
142 const GCNRegPressure& O,
143 unsigned MaxOccupancy) const {
144 const auto SGPROcc = std::min(MaxOccupancy,
145 ST.getOccupancyWithNumSGPRs(getSGPRNum()));
146 const auto VGPROcc = std::min(MaxOccupancy,
147 ST.getOccupancyWithNumVGPRs(getVGPRNum()));
148 const auto OtherSGPROcc = std::min(MaxOccupancy,
149 ST.getOccupancyWithNumSGPRs(O.getSGPRNum()));
150 const auto OtherVGPROcc = std::min(MaxOccupancy,
151 ST.getOccupancyWithNumVGPRs(O.getVGPRNum()));
152
153 const auto Occ = std::min(SGPROcc, VGPROcc);
154 const auto OtherOcc = std::min(OtherSGPROcc, OtherVGPROcc);
155 if (Occ != OtherOcc)
156 return Occ > OtherOcc;
157
158 bool SGPRImportant = SGPROcc < VGPROcc;
159 const bool OtherSGPRImportant = OtherSGPROcc < OtherVGPROcc;
160
161 // if both pressures disagree on what is more important compare vgprs
162 if (SGPRImportant != OtherSGPRImportant) {
163 SGPRImportant = false;
164 }
165
166 // compare large regs pressure
167 bool SGPRFirst = SGPRImportant;
168 for (int I = 2; I > 0; --I, SGPRFirst = !SGPRFirst) {
169 if (SGPRFirst) {
170 auto SW = getSGPRTuplesWeight();
171 auto OtherSW = O.getSGPRTuplesWeight();
172 if (SW != OtherSW)
173 return SW < OtherSW;
174 } else {
175 auto VW = getVGPRTuplesWeight();
176 auto OtherVW = O.getVGPRTuplesWeight();
177 if (VW != OtherVW)
178 return VW < OtherVW;
179 }
180 }
181 return SGPRImportant ? (getSGPRNum() < O.getSGPRNum()):
182 (getVGPRNum() < O.getVGPRNum());
183 }
184
185 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
186 LLVM_DUMP_METHOD
print(raw_ostream & OS,const GCNSubtarget * ST) const187 void GCNRegPressure::print(raw_ostream &OS, const GCNSubtarget *ST) const {
188 OS << "VGPRs: " << Value[VGPR32] << ' ';
189 OS << "AGPRs: " << Value[AGPR32];
190 if (ST) OS << "(O" << ST->getOccupancyWithNumVGPRs(getVGPRNum()) << ')';
191 OS << ", SGPRs: " << getSGPRNum();
192 if (ST) OS << "(O" << ST->getOccupancyWithNumSGPRs(getSGPRNum()) << ')';
193 OS << ", LVGPR WT: " << getVGPRTuplesWeight()
194 << ", LSGPR WT: " << getSGPRTuplesWeight();
195 if (ST) OS << " -> Occ: " << getOccupancy(*ST);
196 OS << '\n';
197 }
198 #endif
199
getDefRegMask(const MachineOperand & MO,const MachineRegisterInfo & MRI)200 static LaneBitmask getDefRegMask(const MachineOperand &MO,
201 const MachineRegisterInfo &MRI) {
202 assert(MO.isDef() && MO.isReg() && Register::isVirtualRegister(MO.getReg()));
203
204 // We don't rely on read-undef flag because in case of tentative schedule
205 // tracking it isn't set correctly yet. This works correctly however since
206 // use mask has been tracked before using LIS.
207 return MO.getSubReg() == 0 ?
208 MRI.getMaxLaneMaskForVReg(MO.getReg()) :
209 MRI.getTargetRegisterInfo()->getSubRegIndexLaneMask(MO.getSubReg());
210 }
211
getUsedRegMask(const MachineOperand & MO,const MachineRegisterInfo & MRI,const LiveIntervals & LIS)212 static LaneBitmask getUsedRegMask(const MachineOperand &MO,
213 const MachineRegisterInfo &MRI,
214 const LiveIntervals &LIS) {
215 assert(MO.isUse() && MO.isReg() && Register::isVirtualRegister(MO.getReg()));
216
217 if (auto SubReg = MO.getSubReg())
218 return MRI.getTargetRegisterInfo()->getSubRegIndexLaneMask(SubReg);
219
220 auto MaxMask = MRI.getMaxLaneMaskForVReg(MO.getReg());
221 if (SIRegisterInfo::getNumCoveredRegs(MaxMask) > 1) // cannot have subregs
222 return MaxMask;
223
224 // For a tentative schedule LIS isn't updated yet but livemask should remain
225 // the same on any schedule. Subreg defs can be reordered but they all must
226 // dominate uses anyway.
227 auto SI = LIS.getInstructionIndex(*MO.getParent()).getBaseIndex();
228 return getLiveLaneMask(MO.getReg(), SI, LIS, MRI);
229 }
230
231 static SmallVector<RegisterMaskPair, 8>
collectVirtualRegUses(const MachineInstr & MI,const LiveIntervals & LIS,const MachineRegisterInfo & MRI)232 collectVirtualRegUses(const MachineInstr &MI, const LiveIntervals &LIS,
233 const MachineRegisterInfo &MRI) {
234 SmallVector<RegisterMaskPair, 8> Res;
235 for (const auto &MO : MI.operands()) {
236 if (!MO.isReg() || !Register::isVirtualRegister(MO.getReg()))
237 continue;
238 if (!MO.isUse() || !MO.readsReg())
239 continue;
240
241 auto const UsedMask = getUsedRegMask(MO, MRI, LIS);
242
243 auto Reg = MO.getReg();
244 auto I = std::find_if(Res.begin(), Res.end(), [Reg](const RegisterMaskPair &RM) {
245 return RM.RegUnit == Reg;
246 });
247 if (I != Res.end())
248 I->LaneMask |= UsedMask;
249 else
250 Res.push_back(RegisterMaskPair(Reg, UsedMask));
251 }
252 return Res;
253 }
254
///////////////////////////////////////////////////////////////////////////////
// GCNRPTracker
257
getLiveLaneMask(unsigned Reg,SlotIndex SI,const LiveIntervals & LIS,const MachineRegisterInfo & MRI)258 LaneBitmask llvm::getLiveLaneMask(unsigned Reg,
259 SlotIndex SI,
260 const LiveIntervals &LIS,
261 const MachineRegisterInfo &MRI) {
262 LaneBitmask LiveMask;
263 const auto &LI = LIS.getInterval(Reg);
264 if (LI.hasSubRanges()) {
265 for (const auto &S : LI.subranges())
266 if (S.liveAt(SI)) {
267 LiveMask |= S.LaneMask;
268 assert(LiveMask < MRI.getMaxLaneMaskForVReg(Reg) ||
269 LiveMask == MRI.getMaxLaneMaskForVReg(Reg));
270 }
271 } else if (LI.liveAt(SI)) {
272 LiveMask = MRI.getMaxLaneMaskForVReg(Reg);
273 }
274 return LiveMask;
275 }
276
getLiveRegs(SlotIndex SI,const LiveIntervals & LIS,const MachineRegisterInfo & MRI)277 GCNRPTracker::LiveRegSet llvm::getLiveRegs(SlotIndex SI,
278 const LiveIntervals &LIS,
279 const MachineRegisterInfo &MRI) {
280 GCNRPTracker::LiveRegSet LiveRegs;
281 for (unsigned I = 0, E = MRI.getNumVirtRegs(); I != E; ++I) {
282 auto Reg = Register::index2VirtReg(I);
283 if (!LIS.hasInterval(Reg))
284 continue;
285 auto LiveMask = getLiveLaneMask(Reg, SI, LIS, MRI);
286 if (LiveMask.any())
287 LiveRegs[Reg] = LiveMask;
288 }
289 return LiveRegs;
290 }
291
reset(const MachineInstr & MI,const LiveRegSet * LiveRegsCopy,bool After)292 void GCNRPTracker::reset(const MachineInstr &MI,
293 const LiveRegSet *LiveRegsCopy,
294 bool After) {
295 const MachineFunction &MF = *MI.getMF();
296 MRI = &MF.getRegInfo();
297 if (LiveRegsCopy) {
298 if (&LiveRegs != LiveRegsCopy)
299 LiveRegs = *LiveRegsCopy;
300 } else {
301 LiveRegs = After ? getLiveRegsAfter(MI, LIS)
302 : getLiveRegsBefore(MI, LIS);
303 }
304
305 MaxPressure = CurPressure = getRegPressure(*MRI, LiveRegs);
306 }
307
reset(const MachineInstr & MI,const LiveRegSet * LiveRegsCopy)308 void GCNUpwardRPTracker::reset(const MachineInstr &MI,
309 const LiveRegSet *LiveRegsCopy) {
310 GCNRPTracker::reset(MI, LiveRegsCopy, true);
311 }
312
recede(const MachineInstr & MI)313 void GCNUpwardRPTracker::recede(const MachineInstr &MI) {
314 assert(MRI && "call reset first");
315
316 LastTrackedMI = &MI;
317
318 if (MI.isDebugInstr())
319 return;
320
321 auto const RegUses = collectVirtualRegUses(MI, LIS, *MRI);
322
323 // calc pressure at the MI (defs + uses)
324 auto AtMIPressure = CurPressure;
325 for (const auto &U : RegUses) {
326 auto LiveMask = LiveRegs[U.RegUnit];
327 AtMIPressure.inc(U.RegUnit, LiveMask, LiveMask | U.LaneMask, *MRI);
328 }
329 // update max pressure
330 MaxPressure = max(AtMIPressure, MaxPressure);
331
332 for (const auto &MO : MI.operands()) {
333 if (!MO.isReg() || !MO.isDef() ||
334 !Register::isVirtualRegister(MO.getReg()) || MO.isDead())
335 continue;
336
337 auto Reg = MO.getReg();
338 auto I = LiveRegs.find(Reg);
339 if (I == LiveRegs.end())
340 continue;
341 auto &LiveMask = I->second;
342 auto PrevMask = LiveMask;
343 LiveMask &= ~getDefRegMask(MO, *MRI);
344 CurPressure.inc(Reg, PrevMask, LiveMask, *MRI);
345 if (LiveMask.none())
346 LiveRegs.erase(I);
347 }
348 for (const auto &U : RegUses) {
349 auto &LiveMask = LiveRegs[U.RegUnit];
350 auto PrevMask = LiveMask;
351 LiveMask |= U.LaneMask;
352 CurPressure.inc(U.RegUnit, PrevMask, LiveMask, *MRI);
353 }
354 assert(CurPressure == getRegPressure(*MRI, LiveRegs));
355 }
356
reset(const MachineInstr & MI,const LiveRegSet * LiveRegsCopy)357 bool GCNDownwardRPTracker::reset(const MachineInstr &MI,
358 const LiveRegSet *LiveRegsCopy) {
359 MRI = &MI.getParent()->getParent()->getRegInfo();
360 LastTrackedMI = nullptr;
361 MBBEnd = MI.getParent()->end();
362 NextMI = &MI;
363 NextMI = skipDebugInstructionsForward(NextMI, MBBEnd);
364 if (NextMI == MBBEnd)
365 return false;
366 GCNRPTracker::reset(*NextMI, LiveRegsCopy, false);
367 return true;
368 }
369
advanceBeforeNext()370 bool GCNDownwardRPTracker::advanceBeforeNext() {
371 assert(MRI && "call reset first");
372
373 NextMI = skipDebugInstructionsForward(NextMI, MBBEnd);
374 if (NextMI == MBBEnd)
375 return false;
376
377 SlotIndex SI = LIS.getInstructionIndex(*NextMI).getBaseIndex();
378 assert(SI.isValid());
379
380 // Remove dead registers or mask bits.
381 for (auto &It : LiveRegs) {
382 const LiveInterval &LI = LIS.getInterval(It.first);
383 if (LI.hasSubRanges()) {
384 for (const auto &S : LI.subranges()) {
385 if (!S.liveAt(SI)) {
386 auto PrevMask = It.second;
387 It.second &= ~S.LaneMask;
388 CurPressure.inc(It.first, PrevMask, It.second, *MRI);
389 }
390 }
391 } else if (!LI.liveAt(SI)) {
392 auto PrevMask = It.second;
393 It.second = LaneBitmask::getNone();
394 CurPressure.inc(It.first, PrevMask, It.second, *MRI);
395 }
396 if (It.second.none())
397 LiveRegs.erase(It.first);
398 }
399
400 MaxPressure = max(MaxPressure, CurPressure);
401
402 return true;
403 }
404
advanceToNext()405 void GCNDownwardRPTracker::advanceToNext() {
406 LastTrackedMI = &*NextMI++;
407
408 // Add new registers or mask bits.
409 for (const auto &MO : LastTrackedMI->operands()) {
410 if (!MO.isReg() || !MO.isDef())
411 continue;
412 Register Reg = MO.getReg();
413 if (!Register::isVirtualRegister(Reg))
414 continue;
415 auto &LiveMask = LiveRegs[Reg];
416 auto PrevMask = LiveMask;
417 LiveMask |= getDefRegMask(MO, *MRI);
418 CurPressure.inc(Reg, PrevMask, LiveMask, *MRI);
419 }
420
421 MaxPressure = max(MaxPressure, CurPressure);
422 }
423
advance()424 bool GCNDownwardRPTracker::advance() {
425 // If we have just called reset live set is actual.
426 if ((NextMI == MBBEnd) || (LastTrackedMI && !advanceBeforeNext()))
427 return false;
428 advanceToNext();
429 return true;
430 }
431
advance(MachineBasicBlock::const_iterator End)432 bool GCNDownwardRPTracker::advance(MachineBasicBlock::const_iterator End) {
433 while (NextMI != End)
434 if (!advance()) return false;
435 return true;
436 }
437
advance(MachineBasicBlock::const_iterator Begin,MachineBasicBlock::const_iterator End,const LiveRegSet * LiveRegsCopy)438 bool GCNDownwardRPTracker::advance(MachineBasicBlock::const_iterator Begin,
439 MachineBasicBlock::const_iterator End,
440 const LiveRegSet *LiveRegsCopy) {
441 reset(*Begin, LiveRegsCopy);
442 return advance(End);
443 }
444
445 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
446 LLVM_DUMP_METHOD
reportMismatch(const GCNRPTracker::LiveRegSet & LISLR,const GCNRPTracker::LiveRegSet & TrackedLR,const TargetRegisterInfo * TRI)447 static void reportMismatch(const GCNRPTracker::LiveRegSet &LISLR,
448 const GCNRPTracker::LiveRegSet &TrackedLR,
449 const TargetRegisterInfo *TRI) {
450 for (auto const &P : TrackedLR) {
451 auto I = LISLR.find(P.first);
452 if (I == LISLR.end()) {
453 dbgs() << " " << printReg(P.first, TRI)
454 << ":L" << PrintLaneMask(P.second)
455 << " isn't found in LIS reported set\n";
456 }
457 else if (I->second != P.second) {
458 dbgs() << " " << printReg(P.first, TRI)
459 << " masks doesn't match: LIS reported "
460 << PrintLaneMask(I->second)
461 << ", tracked "
462 << PrintLaneMask(P.second)
463 << '\n';
464 }
465 }
466 for (auto const &P : LISLR) {
467 auto I = TrackedLR.find(P.first);
468 if (I == TrackedLR.end()) {
469 dbgs() << " " << printReg(P.first, TRI)
470 << ":L" << PrintLaneMask(P.second)
471 << " isn't found in tracked set\n";
472 }
473 }
474 }
475
isValid() const476 bool GCNUpwardRPTracker::isValid() const {
477 const auto &SI = LIS.getInstructionIndex(*LastTrackedMI).getBaseIndex();
478 const auto LISLR = llvm::getLiveRegs(SI, LIS, *MRI);
479 const auto &TrackedLR = LiveRegs;
480
481 if (!isEqual(LISLR, TrackedLR)) {
482 dbgs() << "\nGCNUpwardRPTracker error: Tracked and"
483 " LIS reported livesets mismatch:\n";
484 printLivesAt(SI, LIS, *MRI);
485 reportMismatch(LISLR, TrackedLR, MRI->getTargetRegisterInfo());
486 return false;
487 }
488
489 auto LISPressure = getRegPressure(*MRI, LISLR);
490 if (LISPressure != CurPressure) {
491 dbgs() << "GCNUpwardRPTracker error: Pressure sets different\nTracked: ";
492 CurPressure.print(dbgs());
493 dbgs() << "LIS rpt: ";
494 LISPressure.print(dbgs());
495 return false;
496 }
497 return true;
498 }
499
printLiveRegs(raw_ostream & OS,const LiveRegSet & LiveRegs,const MachineRegisterInfo & MRI)500 void GCNRPTracker::printLiveRegs(raw_ostream &OS, const LiveRegSet& LiveRegs,
501 const MachineRegisterInfo &MRI) {
502 const TargetRegisterInfo *TRI = MRI.getTargetRegisterInfo();
503 for (unsigned I = 0, E = MRI.getNumVirtRegs(); I != E; ++I) {
504 unsigned Reg = Register::index2VirtReg(I);
505 auto It = LiveRegs.find(Reg);
506 if (It != LiveRegs.end() && It->second.any())
507 OS << ' ' << printVRegOrUnit(Reg, TRI) << ':'
508 << PrintLaneMask(It->second);
509 }
510 OS << '\n';
511 }
512 #endif
513