1 //===- GCNRegPressure.cpp -------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 ///
9 /// \file
10 /// This file implements the GCNRegPressure class.
11 ///
12 //===----------------------------------------------------------------------===//
13 
14 #include "GCNRegPressure.h"
15 #include "AMDGPUSubtarget.h"
16 #include "SIRegisterInfo.h"
17 #include "llvm/ADT/SmallVector.h"
18 #include "llvm/CodeGen/LiveInterval.h"
19 #include "llvm/CodeGen/LiveIntervals.h"
20 #include "llvm/CodeGen/MachineInstr.h"
21 #include "llvm/CodeGen/MachineOperand.h"
22 #include "llvm/CodeGen/MachineRegisterInfo.h"
23 #include "llvm/CodeGen/RegisterPressure.h"
24 #include "llvm/CodeGen/SlotIndexes.h"
25 #include "llvm/CodeGen/TargetRegisterInfo.h"
26 #include "llvm/Config/llvm-config.h"
27 #include "llvm/MC/LaneBitmask.h"
28 #include "llvm/Support/Compiler.h"
29 #include "llvm/Support/Debug.h"
30 #include "llvm/Support/ErrorHandling.h"
31 #include "llvm/Support/raw_ostream.h"
32 #include <algorithm>
33 #include <cassert>
34 
35 using namespace llvm;
36 
37 #define DEBUG_TYPE "machine-scheduler"
38 
39 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
40 LLVM_DUMP_METHOD
41 void llvm::printLivesAt(SlotIndex SI,
42                         const LiveIntervals &LIS,
43                         const MachineRegisterInfo &MRI) {
44   dbgs() << "Live regs at " << SI << ": "
45          << *LIS.getInstructionFromIndex(SI);
46   unsigned Num = 0;
47   for (unsigned I = 0, E = MRI.getNumVirtRegs(); I != E; ++I) {
48     const unsigned Reg = Register::index2VirtReg(I);
49     if (!LIS.hasInterval(Reg))
50       continue;
51     const auto &LI = LIS.getInterval(Reg);
52     if (LI.hasSubRanges()) {
53       bool firstTime = true;
54       for (const auto &S : LI.subranges()) {
55         if (!S.liveAt(SI)) continue;
56         if (firstTime) {
57           dbgs() << "  " << printReg(Reg, MRI.getTargetRegisterInfo())
58                  << '\n';
59           firstTime = false;
60         }
61         dbgs() << "  " << S << '\n';
62         ++Num;
63       }
64     } else if (LI.liveAt(SI)) {
65       dbgs() << "  " << LI << '\n';
66       ++Num;
67     }
68   }
69   if (!Num) dbgs() << "  <none>\n";
70 }
71 #endif
72 
73 bool llvm::isEqual(const GCNRPTracker::LiveRegSet &S1,
74                    const GCNRPTracker::LiveRegSet &S2) {
75   if (S1.size() != S2.size())
76     return false;
77 
78   for (const auto &P : S1) {
79     auto I = S2.find(P.first);
80     if (I == S2.end() || I->second != P.second)
81       return false;
82   }
83   return true;
84 }
85 
86 
87 ///////////////////////////////////////////////////////////////////////////////
88 // GCNRegPressure
89 
90 unsigned GCNRegPressure::getRegKind(unsigned Reg,
91                                     const MachineRegisterInfo &MRI) {
92   assert(Register::isVirtualRegister(Reg));
93   const auto RC = MRI.getRegClass(Reg);
94   auto STI = static_cast<const SIRegisterInfo*>(MRI.getTargetRegisterInfo());
95   return STI->isSGPRClass(RC) ?
96     (STI->getRegSizeInBits(*RC) == 32 ? SGPR32 : SGPR_TUPLE) :
97     STI->hasAGPRs(RC) ?
98       (STI->getRegSizeInBits(*RC) == 32 ? AGPR32 : AGPR_TUPLE) :
99       (STI->getRegSizeInBits(*RC) == 32 ? VGPR32 : VGPR_TUPLE);
100 }
101 
102 void GCNRegPressure::inc(unsigned Reg,
103                          LaneBitmask PrevMask,
104                          LaneBitmask NewMask,
105                          const MachineRegisterInfo &MRI) {
106   if (SIRegisterInfo::getNumCoveredRegs(NewMask) ==
107       SIRegisterInfo::getNumCoveredRegs(PrevMask))
108     return;
109 
110   int Sign = 1;
111   if (NewMask < PrevMask) {
112     std::swap(NewMask, PrevMask);
113     Sign = -1;
114   }
115 
116   switch (auto Kind = getRegKind(Reg, MRI)) {
117   case SGPR32:
118   case VGPR32:
119   case AGPR32:
120     Value[Kind] += Sign;
121     break;
122 
123   case SGPR_TUPLE:
124   case VGPR_TUPLE:
125   case AGPR_TUPLE:
126     assert(PrevMask < NewMask);
127 
128     Value[Kind == SGPR_TUPLE ? SGPR32 : Kind == AGPR_TUPLE ? AGPR32 : VGPR32] +=
129       Sign * SIRegisterInfo::getNumCoveredRegs(~PrevMask & NewMask);
130 
131     if (PrevMask.none()) {
132       assert(NewMask.any());
133       Value[Kind] += Sign * MRI.getPressureSets(Reg).getWeight();
134     }
135     break;
136 
137   default: llvm_unreachable("Unknown register kind");
138   }
139 }
140 
141 bool GCNRegPressure::less(const GCNSubtarget &ST,
142                           const GCNRegPressure& O,
143                           unsigned MaxOccupancy) const {
144   const auto SGPROcc = std::min(MaxOccupancy,
145                                 ST.getOccupancyWithNumSGPRs(getSGPRNum()));
146   const auto VGPROcc = std::min(MaxOccupancy,
147                                 ST.getOccupancyWithNumVGPRs(getVGPRNum()));
148   const auto OtherSGPROcc = std::min(MaxOccupancy,
149                                 ST.getOccupancyWithNumSGPRs(O.getSGPRNum()));
150   const auto OtherVGPROcc = std::min(MaxOccupancy,
151                                 ST.getOccupancyWithNumVGPRs(O.getVGPRNum()));
152 
153   const auto Occ = std::min(SGPROcc, VGPROcc);
154   const auto OtherOcc = std::min(OtherSGPROcc, OtherVGPROcc);
155   if (Occ != OtherOcc)
156     return Occ > OtherOcc;
157 
158   bool SGPRImportant = SGPROcc < VGPROcc;
159   const bool OtherSGPRImportant = OtherSGPROcc < OtherVGPROcc;
160 
161   // if both pressures disagree on what is more important compare vgprs
162   if (SGPRImportant != OtherSGPRImportant) {
163     SGPRImportant = false;
164   }
165 
166   // compare large regs pressure
167   bool SGPRFirst = SGPRImportant;
168   for (int I = 2; I > 0; --I, SGPRFirst = !SGPRFirst) {
169     if (SGPRFirst) {
170       auto SW = getSGPRTuplesWeight();
171       auto OtherSW = O.getSGPRTuplesWeight();
172       if (SW != OtherSW)
173         return SW < OtherSW;
174     } else {
175       auto VW = getVGPRTuplesWeight();
176       auto OtherVW = O.getVGPRTuplesWeight();
177       if (VW != OtherVW)
178         return VW < OtherVW;
179     }
180   }
181   return SGPRImportant ? (getSGPRNum() < O.getSGPRNum()):
182                          (getVGPRNum() < O.getVGPRNum());
183 }
184 
185 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
186 LLVM_DUMP_METHOD
187 void GCNRegPressure::print(raw_ostream &OS, const GCNSubtarget *ST) const {
188   OS << "VGPRs: " << Value[VGPR32] << ' ';
189   OS << "AGPRs: " << Value[AGPR32];
190   if (ST) OS << "(O" << ST->getOccupancyWithNumVGPRs(getVGPRNum()) << ')';
191   OS << ", SGPRs: " << getSGPRNum();
192   if (ST) OS << "(O" << ST->getOccupancyWithNumSGPRs(getSGPRNum()) << ')';
193   OS << ", LVGPR WT: " << getVGPRTuplesWeight()
194      << ", LSGPR WT: " << getSGPRTuplesWeight();
195   if (ST) OS << " -> Occ: " << getOccupancy(*ST);
196   OS << '\n';
197 }
198 #endif
199 
200 static LaneBitmask getDefRegMask(const MachineOperand &MO,
201                                  const MachineRegisterInfo &MRI) {
202   assert(MO.isDef() && MO.isReg() && Register::isVirtualRegister(MO.getReg()));
203 
204   // We don't rely on read-undef flag because in case of tentative schedule
205   // tracking it isn't set correctly yet. This works correctly however since
206   // use mask has been tracked before using LIS.
207   return MO.getSubReg() == 0 ?
208     MRI.getMaxLaneMaskForVReg(MO.getReg()) :
209     MRI.getTargetRegisterInfo()->getSubRegIndexLaneMask(MO.getSubReg());
210 }
211 
212 static LaneBitmask getUsedRegMask(const MachineOperand &MO,
213                                   const MachineRegisterInfo &MRI,
214                                   const LiveIntervals &LIS) {
215   assert(MO.isUse() && MO.isReg() && Register::isVirtualRegister(MO.getReg()));
216 
217   if (auto SubReg = MO.getSubReg())
218     return MRI.getTargetRegisterInfo()->getSubRegIndexLaneMask(SubReg);
219 
220   auto MaxMask = MRI.getMaxLaneMaskForVReg(MO.getReg());
221   if (SIRegisterInfo::getNumCoveredRegs(MaxMask) > 1) // cannot have subregs
222     return MaxMask;
223 
224   // For a tentative schedule LIS isn't updated yet but livemask should remain
225   // the same on any schedule. Subreg defs can be reordered but they all must
226   // dominate uses anyway.
227   auto SI = LIS.getInstructionIndex(*MO.getParent()).getBaseIndex();
228   return getLiveLaneMask(MO.getReg(), SI, LIS, MRI);
229 }
230 
231 static SmallVector<RegisterMaskPair, 8>
232 collectVirtualRegUses(const MachineInstr &MI, const LiveIntervals &LIS,
233                       const MachineRegisterInfo &MRI) {
234   SmallVector<RegisterMaskPair, 8> Res;
235   for (const auto &MO : MI.operands()) {
236     if (!MO.isReg() || !Register::isVirtualRegister(MO.getReg()))
237       continue;
238     if (!MO.isUse() || !MO.readsReg())
239       continue;
240 
241     auto const UsedMask = getUsedRegMask(MO, MRI, LIS);
242 
243     auto Reg = MO.getReg();
244     auto I = std::find_if(Res.begin(), Res.end(), [Reg](const RegisterMaskPair &RM) {
245       return RM.RegUnit == Reg;
246     });
247     if (I != Res.end())
248       I->LaneMask |= UsedMask;
249     else
250       Res.push_back(RegisterMaskPair(Reg, UsedMask));
251   }
252   return Res;
253 }
254 
255 ///////////////////////////////////////////////////////////////////////////////
256 // GCNRPTracker
257 
258 LaneBitmask llvm::getLiveLaneMask(unsigned Reg,
259                                   SlotIndex SI,
260                                   const LiveIntervals &LIS,
261                                   const MachineRegisterInfo &MRI) {
262   LaneBitmask LiveMask;
263   const auto &LI = LIS.getInterval(Reg);
264   if (LI.hasSubRanges()) {
265     for (const auto &S : LI.subranges())
266       if (S.liveAt(SI)) {
267         LiveMask |= S.LaneMask;
268         assert(LiveMask < MRI.getMaxLaneMaskForVReg(Reg) ||
269                LiveMask == MRI.getMaxLaneMaskForVReg(Reg));
270       }
271   } else if (LI.liveAt(SI)) {
272     LiveMask = MRI.getMaxLaneMaskForVReg(Reg);
273   }
274   return LiveMask;
275 }
276 
277 GCNRPTracker::LiveRegSet llvm::getLiveRegs(SlotIndex SI,
278                                            const LiveIntervals &LIS,
279                                            const MachineRegisterInfo &MRI) {
280   GCNRPTracker::LiveRegSet LiveRegs;
281   for (unsigned I = 0, E = MRI.getNumVirtRegs(); I != E; ++I) {
282     auto Reg = Register::index2VirtReg(I);
283     if (!LIS.hasInterval(Reg))
284       continue;
285     auto LiveMask = getLiveLaneMask(Reg, SI, LIS, MRI);
286     if (LiveMask.any())
287       LiveRegs[Reg] = LiveMask;
288   }
289   return LiveRegs;
290 }
291 
292 void GCNRPTracker::reset(const MachineInstr &MI,
293                          const LiveRegSet *LiveRegsCopy,
294                          bool After) {
295   const MachineFunction &MF = *MI.getMF();
296   MRI = &MF.getRegInfo();
297   if (LiveRegsCopy) {
298     if (&LiveRegs != LiveRegsCopy)
299       LiveRegs = *LiveRegsCopy;
300   } else {
301     LiveRegs = After ? getLiveRegsAfter(MI, LIS)
302                      : getLiveRegsBefore(MI, LIS);
303   }
304 
305   MaxPressure = CurPressure = getRegPressure(*MRI, LiveRegs);
306 }
307 
308 void GCNUpwardRPTracker::reset(const MachineInstr &MI,
309                                const LiveRegSet *LiveRegsCopy) {
310   GCNRPTracker::reset(MI, LiveRegsCopy, true);
311 }
312 
313 void GCNUpwardRPTracker::recede(const MachineInstr &MI) {
314   assert(MRI && "call reset first");
315 
316   LastTrackedMI = &MI;
317 
318   if (MI.isDebugInstr())
319     return;
320 
321   auto const RegUses = collectVirtualRegUses(MI, LIS, *MRI);
322 
323   // calc pressure at the MI (defs + uses)
324   auto AtMIPressure = CurPressure;
325   for (const auto &U : RegUses) {
326     auto LiveMask = LiveRegs[U.RegUnit];
327     AtMIPressure.inc(U.RegUnit, LiveMask, LiveMask | U.LaneMask, *MRI);
328   }
329   // update max pressure
330   MaxPressure = max(AtMIPressure, MaxPressure);
331 
332   for (const auto &MO : MI.operands()) {
333     if (!MO.isReg() || !MO.isDef() ||
334         !Register::isVirtualRegister(MO.getReg()) || MO.isDead())
335       continue;
336 
337     auto Reg = MO.getReg();
338     auto I = LiveRegs.find(Reg);
339     if (I == LiveRegs.end())
340       continue;
341     auto &LiveMask = I->second;
342     auto PrevMask = LiveMask;
343     LiveMask &= ~getDefRegMask(MO, *MRI);
344     CurPressure.inc(Reg, PrevMask, LiveMask, *MRI);
345     if (LiveMask.none())
346       LiveRegs.erase(I);
347   }
348   for (const auto &U : RegUses) {
349     auto &LiveMask = LiveRegs[U.RegUnit];
350     auto PrevMask = LiveMask;
351     LiveMask |= U.LaneMask;
352     CurPressure.inc(U.RegUnit, PrevMask, LiveMask, *MRI);
353   }
354   assert(CurPressure == getRegPressure(*MRI, LiveRegs));
355 }
356 
357 bool GCNDownwardRPTracker::reset(const MachineInstr &MI,
358                                  const LiveRegSet *LiveRegsCopy) {
359   MRI = &MI.getParent()->getParent()->getRegInfo();
360   LastTrackedMI = nullptr;
361   MBBEnd = MI.getParent()->end();
362   NextMI = &MI;
363   NextMI = skipDebugInstructionsForward(NextMI, MBBEnd);
364   if (NextMI == MBBEnd)
365     return false;
366   GCNRPTracker::reset(*NextMI, LiveRegsCopy, false);
367   return true;
368 }
369 
370 bool GCNDownwardRPTracker::advanceBeforeNext() {
371   assert(MRI && "call reset first");
372 
373   NextMI = skipDebugInstructionsForward(NextMI, MBBEnd);
374   if (NextMI == MBBEnd)
375     return false;
376 
377   SlotIndex SI = LIS.getInstructionIndex(*NextMI).getBaseIndex();
378   assert(SI.isValid());
379 
380   // Remove dead registers or mask bits.
381   for (auto &It : LiveRegs) {
382     const LiveInterval &LI = LIS.getInterval(It.first);
383     if (LI.hasSubRanges()) {
384       for (const auto &S : LI.subranges()) {
385         if (!S.liveAt(SI)) {
386           auto PrevMask = It.second;
387           It.second &= ~S.LaneMask;
388           CurPressure.inc(It.first, PrevMask, It.second, *MRI);
389         }
390       }
391     } else if (!LI.liveAt(SI)) {
392       auto PrevMask = It.second;
393       It.second = LaneBitmask::getNone();
394       CurPressure.inc(It.first, PrevMask, It.second, *MRI);
395     }
396     if (It.second.none())
397       LiveRegs.erase(It.first);
398   }
399 
400   MaxPressure = max(MaxPressure, CurPressure);
401 
402   return true;
403 }
404 
405 void GCNDownwardRPTracker::advanceToNext() {
406   LastTrackedMI = &*NextMI++;
407 
408   // Add new registers or mask bits.
409   for (const auto &MO : LastTrackedMI->operands()) {
410     if (!MO.isReg() || !MO.isDef())
411       continue;
412     Register Reg = MO.getReg();
413     if (!Register::isVirtualRegister(Reg))
414       continue;
415     auto &LiveMask = LiveRegs[Reg];
416     auto PrevMask = LiveMask;
417     LiveMask |= getDefRegMask(MO, *MRI);
418     CurPressure.inc(Reg, PrevMask, LiveMask, *MRI);
419   }
420 
421   MaxPressure = max(MaxPressure, CurPressure);
422 }
423 
424 bool GCNDownwardRPTracker::advance() {
425   // If we have just called reset live set is actual.
426   if ((NextMI == MBBEnd) || (LastTrackedMI && !advanceBeforeNext()))
427     return false;
428   advanceToNext();
429   return true;
430 }
431 
432 bool GCNDownwardRPTracker::advance(MachineBasicBlock::const_iterator End) {
433   while (NextMI != End)
434     if (!advance()) return false;
435   return true;
436 }
437 
438 bool GCNDownwardRPTracker::advance(MachineBasicBlock::const_iterator Begin,
439                                    MachineBasicBlock::const_iterator End,
440                                    const LiveRegSet *LiveRegsCopy) {
441   reset(*Begin, LiveRegsCopy);
442   return advance(End);
443 }
444 
445 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
446 LLVM_DUMP_METHOD
447 static void reportMismatch(const GCNRPTracker::LiveRegSet &LISLR,
448                            const GCNRPTracker::LiveRegSet &TrackedLR,
449                            const TargetRegisterInfo *TRI) {
450   for (auto const &P : TrackedLR) {
451     auto I = LISLR.find(P.first);
452     if (I == LISLR.end()) {
453       dbgs() << "  " << printReg(P.first, TRI)
454              << ":L" << PrintLaneMask(P.second)
455              << " isn't found in LIS reported set\n";
456     }
457     else if (I->second != P.second) {
458       dbgs() << "  " << printReg(P.first, TRI)
459         << " masks doesn't match: LIS reported "
460         << PrintLaneMask(I->second)
461         << ", tracked "
462         << PrintLaneMask(P.second)
463         << '\n';
464     }
465   }
466   for (auto const &P : LISLR) {
467     auto I = TrackedLR.find(P.first);
468     if (I == TrackedLR.end()) {
469       dbgs() << "  " << printReg(P.first, TRI)
470              << ":L" << PrintLaneMask(P.second)
471              << " isn't found in tracked set\n";
472     }
473   }
474 }
475 
476 bool GCNUpwardRPTracker::isValid() const {
477   const auto &SI = LIS.getInstructionIndex(*LastTrackedMI).getBaseIndex();
478   const auto LISLR = llvm::getLiveRegs(SI, LIS, *MRI);
479   const auto &TrackedLR = LiveRegs;
480 
481   if (!isEqual(LISLR, TrackedLR)) {
482     dbgs() << "\nGCNUpwardRPTracker error: Tracked and"
483               " LIS reported livesets mismatch:\n";
484     printLivesAt(SI, LIS, *MRI);
485     reportMismatch(LISLR, TrackedLR, MRI->getTargetRegisterInfo());
486     return false;
487   }
488 
489   auto LISPressure = getRegPressure(*MRI, LISLR);
490   if (LISPressure != CurPressure) {
491     dbgs() << "GCNUpwardRPTracker error: Pressure sets different\nTracked: ";
492     CurPressure.print(dbgs());
493     dbgs() << "LIS rpt: ";
494     LISPressure.print(dbgs());
495     return false;
496   }
497   return true;
498 }
499 
500 void GCNRPTracker::printLiveRegs(raw_ostream &OS, const LiveRegSet& LiveRegs,
501                                  const MachineRegisterInfo &MRI) {
502   const TargetRegisterInfo *TRI = MRI.getTargetRegisterInfo();
503   for (unsigned I = 0, E = MRI.getNumVirtRegs(); I != E; ++I) {
504     unsigned Reg = Register::index2VirtReg(I);
505     auto It = LiveRegs.find(Reg);
506     if (It != LiveRegs.end() && It->second.any())
507       OS << ' ' << printVRegOrUnit(Reg, TRI) << ':'
508          << PrintLaneMask(It->second);
509   }
510   OS << '\n';
511 }
512 #endif
513