1 //===- GCNRegPressure.cpp -------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 ///
9 /// \file
10 /// This file implements the GCNRegPressure class.
11 ///
12 //===----------------------------------------------------------------------===//
13 
14 #include "GCNRegPressure.h"
15 #include "llvm/CodeGen/RegisterPressure.h"
16 
17 using namespace llvm;
18 
19 #define DEBUG_TYPE "machine-scheduler"
20 
21 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
22 LLVM_DUMP_METHOD
printLivesAt(SlotIndex SI,const LiveIntervals & LIS,const MachineRegisterInfo & MRI)23 void llvm::printLivesAt(SlotIndex SI,
24                         const LiveIntervals &LIS,
25                         const MachineRegisterInfo &MRI) {
26   dbgs() << "Live regs at " << SI << ": "
27          << *LIS.getInstructionFromIndex(SI);
28   unsigned Num = 0;
29   for (unsigned I = 0, E = MRI.getNumVirtRegs(); I != E; ++I) {
30     const unsigned Reg = Register::index2VirtReg(I);
31     if (!LIS.hasInterval(Reg))
32       continue;
33     const auto &LI = LIS.getInterval(Reg);
34     if (LI.hasSubRanges()) {
35       bool firstTime = true;
36       for (const auto &S : LI.subranges()) {
37         if (!S.liveAt(SI)) continue;
38         if (firstTime) {
39           dbgs() << "  " << printReg(Reg, MRI.getTargetRegisterInfo())
40                  << '\n';
41           firstTime = false;
42         }
43         dbgs() << "  " << S << '\n';
44         ++Num;
45       }
46     } else if (LI.liveAt(SI)) {
47       dbgs() << "  " << LI << '\n';
48       ++Num;
49     }
50   }
51   if (!Num) dbgs() << "  <none>\n";
52 }
53 #endif
54 
isEqual(const GCNRPTracker::LiveRegSet & S1,const GCNRPTracker::LiveRegSet & S2)55 bool llvm::isEqual(const GCNRPTracker::LiveRegSet &S1,
56                    const GCNRPTracker::LiveRegSet &S2) {
57   if (S1.size() != S2.size())
58     return false;
59 
60   for (const auto &P : S1) {
61     auto I = S2.find(P.first);
62     if (I == S2.end() || I->second != P.second)
63       return false;
64   }
65   return true;
66 }
67 
68 
69 ///////////////////////////////////////////////////////////////////////////////
70 // GCNRegPressure
71 
getRegKind(Register Reg,const MachineRegisterInfo & MRI)72 unsigned GCNRegPressure::getRegKind(Register Reg,
73                                     const MachineRegisterInfo &MRI) {
74   assert(Reg.isVirtual());
75   const auto RC = MRI.getRegClass(Reg);
76   auto STI = static_cast<const SIRegisterInfo*>(MRI.getTargetRegisterInfo());
77   return STI->isSGPRClass(RC) ?
78     (STI->getRegSizeInBits(*RC) == 32 ? SGPR32 : SGPR_TUPLE) :
79     STI->hasAGPRs(RC) ?
80       (STI->getRegSizeInBits(*RC) == 32 ? AGPR32 : AGPR_TUPLE) :
81       (STI->getRegSizeInBits(*RC) == 32 ? VGPR32 : VGPR_TUPLE);
82 }
83 
inc(unsigned Reg,LaneBitmask PrevMask,LaneBitmask NewMask,const MachineRegisterInfo & MRI)84 void GCNRegPressure::inc(unsigned Reg,
85                          LaneBitmask PrevMask,
86                          LaneBitmask NewMask,
87                          const MachineRegisterInfo &MRI) {
88   if (SIRegisterInfo::getNumCoveredRegs(NewMask) ==
89       SIRegisterInfo::getNumCoveredRegs(PrevMask))
90     return;
91 
92   int Sign = 1;
93   if (NewMask < PrevMask) {
94     std::swap(NewMask, PrevMask);
95     Sign = -1;
96   }
97 
98   switch (auto Kind = getRegKind(Reg, MRI)) {
99   case SGPR32:
100   case VGPR32:
101   case AGPR32:
102     Value[Kind] += Sign;
103     break;
104 
105   case SGPR_TUPLE:
106   case VGPR_TUPLE:
107   case AGPR_TUPLE:
108     assert(PrevMask < NewMask);
109 
110     Value[Kind == SGPR_TUPLE ? SGPR32 : Kind == AGPR_TUPLE ? AGPR32 : VGPR32] +=
111       Sign * SIRegisterInfo::getNumCoveredRegs(~PrevMask & NewMask);
112 
113     if (PrevMask.none()) {
114       assert(NewMask.any());
115       Value[Kind] += Sign * MRI.getPressureSets(Reg).getWeight();
116     }
117     break;
118 
119   default: llvm_unreachable("Unknown register kind");
120   }
121 }
122 
less(const GCNSubtarget & ST,const GCNRegPressure & O,unsigned MaxOccupancy) const123 bool GCNRegPressure::less(const GCNSubtarget &ST,
124                           const GCNRegPressure& O,
125                           unsigned MaxOccupancy) const {
126   const auto SGPROcc = std::min(MaxOccupancy,
127                                 ST.getOccupancyWithNumSGPRs(getSGPRNum()));
128   const auto VGPROcc = std::min(MaxOccupancy,
129                                 ST.getOccupancyWithNumVGPRs(getVGPRNum()));
130   const auto OtherSGPROcc = std::min(MaxOccupancy,
131                                 ST.getOccupancyWithNumSGPRs(O.getSGPRNum()));
132   const auto OtherVGPROcc = std::min(MaxOccupancy,
133                                 ST.getOccupancyWithNumVGPRs(O.getVGPRNum()));
134 
135   const auto Occ = std::min(SGPROcc, VGPROcc);
136   const auto OtherOcc = std::min(OtherSGPROcc, OtherVGPROcc);
137   if (Occ != OtherOcc)
138     return Occ > OtherOcc;
139 
140   bool SGPRImportant = SGPROcc < VGPROcc;
141   const bool OtherSGPRImportant = OtherSGPROcc < OtherVGPROcc;
142 
143   // if both pressures disagree on what is more important compare vgprs
144   if (SGPRImportant != OtherSGPRImportant) {
145     SGPRImportant = false;
146   }
147 
148   // compare large regs pressure
149   bool SGPRFirst = SGPRImportant;
150   for (int I = 2; I > 0; --I, SGPRFirst = !SGPRFirst) {
151     if (SGPRFirst) {
152       auto SW = getSGPRTuplesWeight();
153       auto OtherSW = O.getSGPRTuplesWeight();
154       if (SW != OtherSW)
155         return SW < OtherSW;
156     } else {
157       auto VW = getVGPRTuplesWeight();
158       auto OtherVW = O.getVGPRTuplesWeight();
159       if (VW != OtherVW)
160         return VW < OtherVW;
161     }
162   }
163   return SGPRImportant ? (getSGPRNum() < O.getSGPRNum()):
164                          (getVGPRNum() < O.getVGPRNum());
165 }
166 
167 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
168 LLVM_DUMP_METHOD
print(raw_ostream & OS,const GCNSubtarget * ST) const169 void GCNRegPressure::print(raw_ostream &OS, const GCNSubtarget *ST) const {
170   OS << "VGPRs: " << Value[VGPR32] << ' ';
171   OS << "AGPRs: " << Value[AGPR32];
172   if (ST) OS << "(O" << ST->getOccupancyWithNumVGPRs(getVGPRNum()) << ')';
173   OS << ", SGPRs: " << getSGPRNum();
174   if (ST) OS << "(O" << ST->getOccupancyWithNumSGPRs(getSGPRNum()) << ')';
175   OS << ", LVGPR WT: " << getVGPRTuplesWeight()
176      << ", LSGPR WT: " << getSGPRTuplesWeight();
177   if (ST) OS << " -> Occ: " << getOccupancy(*ST);
178   OS << '\n';
179 }
180 #endif
181 
getDefRegMask(const MachineOperand & MO,const MachineRegisterInfo & MRI)182 static LaneBitmask getDefRegMask(const MachineOperand &MO,
183                                  const MachineRegisterInfo &MRI) {
184   assert(MO.isDef() && MO.isReg() && MO.getReg().isVirtual());
185 
186   // We don't rely on read-undef flag because in case of tentative schedule
187   // tracking it isn't set correctly yet. This works correctly however since
188   // use mask has been tracked before using LIS.
189   return MO.getSubReg() == 0 ?
190     MRI.getMaxLaneMaskForVReg(MO.getReg()) :
191     MRI.getTargetRegisterInfo()->getSubRegIndexLaneMask(MO.getSubReg());
192 }
193 
getUsedRegMask(const MachineOperand & MO,const MachineRegisterInfo & MRI,const LiveIntervals & LIS)194 static LaneBitmask getUsedRegMask(const MachineOperand &MO,
195                                   const MachineRegisterInfo &MRI,
196                                   const LiveIntervals &LIS) {
197   assert(MO.isUse() && MO.isReg() && MO.getReg().isVirtual());
198 
199   if (auto SubReg = MO.getSubReg())
200     return MRI.getTargetRegisterInfo()->getSubRegIndexLaneMask(SubReg);
201 
202   auto MaxMask = MRI.getMaxLaneMaskForVReg(MO.getReg());
203   if (SIRegisterInfo::getNumCoveredRegs(MaxMask) > 1) // cannot have subregs
204     return MaxMask;
205 
206   // For a tentative schedule LIS isn't updated yet but livemask should remain
207   // the same on any schedule. Subreg defs can be reordered but they all must
208   // dominate uses anyway.
209   auto SI = LIS.getInstructionIndex(*MO.getParent()).getBaseIndex();
210   return getLiveLaneMask(MO.getReg(), SI, LIS, MRI);
211 }
212 
213 static SmallVector<RegisterMaskPair, 8>
collectVirtualRegUses(const MachineInstr & MI,const LiveIntervals & LIS,const MachineRegisterInfo & MRI)214 collectVirtualRegUses(const MachineInstr &MI, const LiveIntervals &LIS,
215                       const MachineRegisterInfo &MRI) {
216   SmallVector<RegisterMaskPair, 8> Res;
217   for (const auto &MO : MI.operands()) {
218     if (!MO.isReg() || !MO.getReg().isVirtual())
219       continue;
220     if (!MO.isUse() || !MO.readsReg())
221       continue;
222 
223     auto const UsedMask = getUsedRegMask(MO, MRI, LIS);
224 
225     auto Reg = MO.getReg();
226     auto I = llvm::find_if(
227         Res, [Reg](const RegisterMaskPair &RM) { return RM.RegUnit == Reg; });
228     if (I != Res.end())
229       I->LaneMask |= UsedMask;
230     else
231       Res.push_back(RegisterMaskPair(Reg, UsedMask));
232   }
233   return Res;
234 }
235 
236 ///////////////////////////////////////////////////////////////////////////////
237 // GCNRPTracker
238 
getLiveLaneMask(unsigned Reg,SlotIndex SI,const LiveIntervals & LIS,const MachineRegisterInfo & MRI)239 LaneBitmask llvm::getLiveLaneMask(unsigned Reg,
240                                   SlotIndex SI,
241                                   const LiveIntervals &LIS,
242                                   const MachineRegisterInfo &MRI) {
243   LaneBitmask LiveMask;
244   const auto &LI = LIS.getInterval(Reg);
245   if (LI.hasSubRanges()) {
246     for (const auto &S : LI.subranges())
247       if (S.liveAt(SI)) {
248         LiveMask |= S.LaneMask;
249         assert(LiveMask < MRI.getMaxLaneMaskForVReg(Reg) ||
250                LiveMask == MRI.getMaxLaneMaskForVReg(Reg));
251       }
252   } else if (LI.liveAt(SI)) {
253     LiveMask = MRI.getMaxLaneMaskForVReg(Reg);
254   }
255   return LiveMask;
256 }
257 
getLiveRegs(SlotIndex SI,const LiveIntervals & LIS,const MachineRegisterInfo & MRI)258 GCNRPTracker::LiveRegSet llvm::getLiveRegs(SlotIndex SI,
259                                            const LiveIntervals &LIS,
260                                            const MachineRegisterInfo &MRI) {
261   GCNRPTracker::LiveRegSet LiveRegs;
262   for (unsigned I = 0, E = MRI.getNumVirtRegs(); I != E; ++I) {
263     auto Reg = Register::index2VirtReg(I);
264     if (!LIS.hasInterval(Reg))
265       continue;
266     auto LiveMask = getLiveLaneMask(Reg, SI, LIS, MRI);
267     if (LiveMask.any())
268       LiveRegs[Reg] = LiveMask;
269   }
270   return LiveRegs;
271 }
272 
reset(const MachineInstr & MI,const LiveRegSet * LiveRegsCopy,bool After)273 void GCNRPTracker::reset(const MachineInstr &MI,
274                          const LiveRegSet *LiveRegsCopy,
275                          bool After) {
276   const MachineFunction &MF = *MI.getMF();
277   MRI = &MF.getRegInfo();
278   if (LiveRegsCopy) {
279     if (&LiveRegs != LiveRegsCopy)
280       LiveRegs = *LiveRegsCopy;
281   } else {
282     LiveRegs = After ? getLiveRegsAfter(MI, LIS)
283                      : getLiveRegsBefore(MI, LIS);
284   }
285 
286   MaxPressure = CurPressure = getRegPressure(*MRI, LiveRegs);
287 }
288 
reset(const MachineInstr & MI,const LiveRegSet * LiveRegsCopy)289 void GCNUpwardRPTracker::reset(const MachineInstr &MI,
290                                const LiveRegSet *LiveRegsCopy) {
291   GCNRPTracker::reset(MI, LiveRegsCopy, true);
292 }
293 
recede(const MachineInstr & MI)294 void GCNUpwardRPTracker::recede(const MachineInstr &MI) {
295   assert(MRI && "call reset first");
296 
297   LastTrackedMI = &MI;
298 
299   if (MI.isDebugInstr())
300     return;
301 
302   auto const RegUses = collectVirtualRegUses(MI, LIS, *MRI);
303 
304   // calc pressure at the MI (defs + uses)
305   auto AtMIPressure = CurPressure;
306   for (const auto &U : RegUses) {
307     auto LiveMask = LiveRegs[U.RegUnit];
308     AtMIPressure.inc(U.RegUnit, LiveMask, LiveMask | U.LaneMask, *MRI);
309   }
310   // update max pressure
311   MaxPressure = max(AtMIPressure, MaxPressure);
312 
313   for (const auto &MO : MI.operands()) {
314     if (!MO.isReg() || !MO.isDef() || !MO.getReg().isVirtual() || MO.isDead())
315       continue;
316 
317     auto Reg = MO.getReg();
318     auto I = LiveRegs.find(Reg);
319     if (I == LiveRegs.end())
320       continue;
321     auto &LiveMask = I->second;
322     auto PrevMask = LiveMask;
323     LiveMask &= ~getDefRegMask(MO, *MRI);
324     CurPressure.inc(Reg, PrevMask, LiveMask, *MRI);
325     if (LiveMask.none())
326       LiveRegs.erase(I);
327   }
328   for (const auto &U : RegUses) {
329     auto &LiveMask = LiveRegs[U.RegUnit];
330     auto PrevMask = LiveMask;
331     LiveMask |= U.LaneMask;
332     CurPressure.inc(U.RegUnit, PrevMask, LiveMask, *MRI);
333   }
334   assert(CurPressure == getRegPressure(*MRI, LiveRegs));
335 }
336 
reset(const MachineInstr & MI,const LiveRegSet * LiveRegsCopy)337 bool GCNDownwardRPTracker::reset(const MachineInstr &MI,
338                                  const LiveRegSet *LiveRegsCopy) {
339   MRI = &MI.getParent()->getParent()->getRegInfo();
340   LastTrackedMI = nullptr;
341   MBBEnd = MI.getParent()->end();
342   NextMI = &MI;
343   NextMI = skipDebugInstructionsForward(NextMI, MBBEnd);
344   if (NextMI == MBBEnd)
345     return false;
346   GCNRPTracker::reset(*NextMI, LiveRegsCopy, false);
347   return true;
348 }
349 
advanceBeforeNext()350 bool GCNDownwardRPTracker::advanceBeforeNext() {
351   assert(MRI && "call reset first");
352 
353   NextMI = skipDebugInstructionsForward(NextMI, MBBEnd);
354   if (NextMI == MBBEnd)
355     return false;
356 
357   SlotIndex SI = LIS.getInstructionIndex(*NextMI).getBaseIndex();
358   assert(SI.isValid());
359 
360   // Remove dead registers or mask bits.
361   for (auto &It : LiveRegs) {
362     const LiveInterval &LI = LIS.getInterval(It.first);
363     if (LI.hasSubRanges()) {
364       for (const auto &S : LI.subranges()) {
365         if (!S.liveAt(SI)) {
366           auto PrevMask = It.second;
367           It.second &= ~S.LaneMask;
368           CurPressure.inc(It.first, PrevMask, It.second, *MRI);
369         }
370       }
371     } else if (!LI.liveAt(SI)) {
372       auto PrevMask = It.second;
373       It.second = LaneBitmask::getNone();
374       CurPressure.inc(It.first, PrevMask, It.second, *MRI);
375     }
376     if (It.second.none())
377       LiveRegs.erase(It.first);
378   }
379 
380   MaxPressure = max(MaxPressure, CurPressure);
381 
382   return true;
383 }
384 
advanceToNext()385 void GCNDownwardRPTracker::advanceToNext() {
386   LastTrackedMI = &*NextMI++;
387 
388   // Add new registers or mask bits.
389   for (const auto &MO : LastTrackedMI->operands()) {
390     if (!MO.isReg() || !MO.isDef())
391       continue;
392     Register Reg = MO.getReg();
393     if (!Reg.isVirtual())
394       continue;
395     auto &LiveMask = LiveRegs[Reg];
396     auto PrevMask = LiveMask;
397     LiveMask |= getDefRegMask(MO, *MRI);
398     CurPressure.inc(Reg, PrevMask, LiveMask, *MRI);
399   }
400 
401   MaxPressure = max(MaxPressure, CurPressure);
402 }
403 
advance()404 bool GCNDownwardRPTracker::advance() {
405   // If we have just called reset live set is actual.
406   if ((NextMI == MBBEnd) || (LastTrackedMI && !advanceBeforeNext()))
407     return false;
408   advanceToNext();
409   return true;
410 }
411 
advance(MachineBasicBlock::const_iterator End)412 bool GCNDownwardRPTracker::advance(MachineBasicBlock::const_iterator End) {
413   while (NextMI != End)
414     if (!advance()) return false;
415   return true;
416 }
417 
advance(MachineBasicBlock::const_iterator Begin,MachineBasicBlock::const_iterator End,const LiveRegSet * LiveRegsCopy)418 bool GCNDownwardRPTracker::advance(MachineBasicBlock::const_iterator Begin,
419                                    MachineBasicBlock::const_iterator End,
420                                    const LiveRegSet *LiveRegsCopy) {
421   reset(*Begin, LiveRegsCopy);
422   return advance(End);
423 }
424 
425 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
426 LLVM_DUMP_METHOD
reportMismatch(const GCNRPTracker::LiveRegSet & LISLR,const GCNRPTracker::LiveRegSet & TrackedLR,const TargetRegisterInfo * TRI)427 static void reportMismatch(const GCNRPTracker::LiveRegSet &LISLR,
428                            const GCNRPTracker::LiveRegSet &TrackedLR,
429                            const TargetRegisterInfo *TRI) {
430   for (auto const &P : TrackedLR) {
431     auto I = LISLR.find(P.first);
432     if (I == LISLR.end()) {
433       dbgs() << "  " << printReg(P.first, TRI)
434              << ":L" << PrintLaneMask(P.second)
435              << " isn't found in LIS reported set\n";
436     }
437     else if (I->second != P.second) {
438       dbgs() << "  " << printReg(P.first, TRI)
439         << " masks doesn't match: LIS reported "
440         << PrintLaneMask(I->second)
441         << ", tracked "
442         << PrintLaneMask(P.second)
443         << '\n';
444     }
445   }
446   for (auto const &P : LISLR) {
447     auto I = TrackedLR.find(P.first);
448     if (I == TrackedLR.end()) {
449       dbgs() << "  " << printReg(P.first, TRI)
450              << ":L" << PrintLaneMask(P.second)
451              << " isn't found in tracked set\n";
452     }
453   }
454 }
455 
isValid() const456 bool GCNUpwardRPTracker::isValid() const {
457   const auto &SI = LIS.getInstructionIndex(*LastTrackedMI).getBaseIndex();
458   const auto LISLR = llvm::getLiveRegs(SI, LIS, *MRI);
459   const auto &TrackedLR = LiveRegs;
460 
461   if (!isEqual(LISLR, TrackedLR)) {
462     dbgs() << "\nGCNUpwardRPTracker error: Tracked and"
463               " LIS reported livesets mismatch:\n";
464     printLivesAt(SI, LIS, *MRI);
465     reportMismatch(LISLR, TrackedLR, MRI->getTargetRegisterInfo());
466     return false;
467   }
468 
469   auto LISPressure = getRegPressure(*MRI, LISLR);
470   if (LISPressure != CurPressure) {
471     dbgs() << "GCNUpwardRPTracker error: Pressure sets different\nTracked: ";
472     CurPressure.print(dbgs());
473     dbgs() << "LIS rpt: ";
474     LISPressure.print(dbgs());
475     return false;
476   }
477   return true;
478 }
479 
printLiveRegs(raw_ostream & OS,const LiveRegSet & LiveRegs,const MachineRegisterInfo & MRI)480 void GCNRPTracker::printLiveRegs(raw_ostream &OS, const LiveRegSet& LiveRegs,
481                                  const MachineRegisterInfo &MRI) {
482   const TargetRegisterInfo *TRI = MRI.getTargetRegisterInfo();
483   for (unsigned I = 0, E = MRI.getNumVirtRegs(); I != E; ++I) {
484     unsigned Reg = Register::index2VirtReg(I);
485     auto It = LiveRegs.find(Reg);
486     if (It != LiveRegs.end() && It->second.any())
487       OS << ' ' << printVRegOrUnit(Reg, TRI) << ':'
488          << PrintLaneMask(It->second);
489   }
490   OS << '\n';
491 }
492 #endif
493