1 //===- RDFRegisters.cpp ---------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "llvm/ADT/BitVector.h"
10 #include "llvm/CodeGen/MachineFunction.h"
11 #include "llvm/CodeGen/MachineInstr.h"
12 #include "llvm/CodeGen/MachineOperand.h"
13 #include "llvm/CodeGen/RDFRegisters.h"
14 #include "llvm/CodeGen/TargetRegisterInfo.h"
15 #include "llvm/MC/LaneBitmask.h"
16 #include "llvm/MC/MCRegisterInfo.h"
17 #include "llvm/Support/ErrorHandling.h"
18 #include "llvm/Support/Format.h"
19 #include "llvm/Support/MathExtras.h"
20 #include "llvm/Support/raw_ostream.h"
21 #include <cassert>
22 #include <cstdint>
23 #include <set>
24 #include <utility>
25 
26 namespace llvm::rdf {
27 
28 PhysicalRegisterInfo::PhysicalRegisterInfo(const TargetRegisterInfo &tri,
29                                            const MachineFunction &mf)
30     : TRI(tri) {
31   RegInfos.resize(TRI.getNumRegs());
32 
33   BitVector BadRC(TRI.getNumRegs());
34   for (const TargetRegisterClass *RC : TRI.regclasses()) {
35     for (MCPhysReg R : *RC) {
36       RegInfo &RI = RegInfos[R];
37       if (RI.RegClass != nullptr && !BadRC[R]) {
38         if (RC->LaneMask != RI.RegClass->LaneMask) {
39           BadRC.set(R);
40           RI.RegClass = nullptr;
41         }
42       } else
43         RI.RegClass = RC;
44     }
45   }
46 
47   UnitInfos.resize(TRI.getNumRegUnits());
48 
49   for (uint32_t U = 0, NU = TRI.getNumRegUnits(); U != NU; ++U) {
50     if (UnitInfos[U].Reg != 0)
51       continue;
52     MCRegUnitRootIterator R(U, &TRI);
53     assert(R.isValid());
54     RegisterId F = *R;
55     ++R;
56     if (R.isValid()) {
57       UnitInfos[U].Mask = LaneBitmask::getAll();
58       UnitInfos[U].Reg = F;
59     } else {
60       for (MCRegUnitMaskIterator I(F, &TRI); I.isValid(); ++I) {
61         std::pair<uint32_t, LaneBitmask> P = *I;
62         UnitInfo &UI = UnitInfos[P.first];
63         UI.Reg = F;
64         UI.Mask = P.second;
65       }
66     }
67   }
68 
69   for (const uint32_t *RM : TRI.getRegMasks())
70     RegMasks.insert(RM);
71   for (const MachineBasicBlock &B : mf)
72     for (const MachineInstr &In : B)
73       for (const MachineOperand &Op : In.operands())
74         if (Op.isRegMask())
75           RegMasks.insert(Op.getRegMask());
76 
77   MaskInfos.resize(RegMasks.size() + 1);
78   for (uint32_t M = 1, NM = RegMasks.size(); M <= NM; ++M) {
79     BitVector PU(TRI.getNumRegUnits());
80     const uint32_t *MB = RegMasks.get(M);
81     for (unsigned I = 1, E = TRI.getNumRegs(); I != E; ++I) {
82       if (!(MB[I / 32] & (1u << (I % 32))))
83         continue;
84       for (MCRegUnit Unit : TRI.regunits(MCRegister::from(I)))
85         PU.set(Unit);
86     }
87     MaskInfos[M].Units = PU.flip();
88   }
89 
90   AliasInfos.resize(TRI.getNumRegUnits());
91   for (uint32_t U = 0, NU = TRI.getNumRegUnits(); U != NU; ++U) {
92     BitVector AS(TRI.getNumRegs());
93     for (MCRegUnitRootIterator R(U, &TRI); R.isValid(); ++R)
94       for (MCPhysReg S : TRI.superregs_inclusive(*R))
95         AS.set(S);
96     AliasInfos[U].Regs = AS;
97   }
98 }
99 
100 bool PhysicalRegisterInfo::alias(RegisterRef RA, RegisterRef RB) const {
101   return !disjoint(getUnits(RA), getUnits(RB));
102 }
103 
104 std::set<RegisterId> PhysicalRegisterInfo::getAliasSet(RegisterId Reg) const {
105   // Do not include Reg in the alias set.
106   std::set<RegisterId> AS;
107   assert(!RegisterRef::isUnitId(Reg) && "No units allowed");
108   if (RegisterRef::isMaskId(Reg)) {
109     // XXX SLOW
110     const uint32_t *MB = getRegMaskBits(Reg);
111     for (unsigned i = 1, e = TRI.getNumRegs(); i != e; ++i) {
112       if (MB[i / 32] & (1u << (i % 32)))
113         continue;
114       AS.insert(i);
115     }
116     return AS;
117   }
118 
119   assert(RegisterRef::isRegId(Reg));
120   for (MCRegAliasIterator AI(Reg, &TRI, false); AI.isValid(); ++AI)
121     AS.insert(*AI);
122 
123   return AS;
124 }
125 
126 std::set<RegisterId> PhysicalRegisterInfo::getUnits(RegisterRef RR) const {
127   std::set<RegisterId> Units;
128 
129   if (RR.Reg == 0)
130     return Units; // Empty
131 
132   if (RR.isReg()) {
133     if (RR.Mask.none())
134       return Units; // Empty
135     for (MCRegUnitMaskIterator UM(RR.idx(), &TRI); UM.isValid(); ++UM) {
136       auto [U, M] = *UM;
137       if ((M & RR.Mask).any())
138         Units.insert(U);
139     }
140     return Units;
141   }
142 
143   assert(RR.isMask());
144   unsigned NumRegs = TRI.getNumRegs();
145   const uint32_t *MB = getRegMaskBits(RR.idx());
146   for (unsigned I = 0, E = (NumRegs + 31) / 32; I != E; ++I) {
147     uint32_t C = ~MB[I]; // Clobbered regs
148     if (I == 0)          // Reg 0 should be ignored
149       C &= maskLeadingOnes<unsigned>(31);
150     if (I + 1 == E && NumRegs % 32 != 0) // Last word may be partial
151       C &= maskTrailingOnes<unsigned>(NumRegs % 32);
152     if (C == 0)
153       continue;
154     while (C != 0) {
155       unsigned T = llvm::countr_zero(C);
156       unsigned CR = 32 * I + T; // Clobbered reg
157       for (MCRegUnit U : TRI.regunits(CR))
158         Units.insert(U);
159       C &= ~(1u << T);
160     }
161   }
162   return Units;
163 }
164 
165 RegisterRef PhysicalRegisterInfo::mapTo(RegisterRef RR, unsigned R) const {
166   if (RR.Reg == R)
167     return RR;
168   if (unsigned Idx = TRI.getSubRegIndex(R, RR.Reg))
169     return RegisterRef(R, TRI.composeSubRegIndexLaneMask(Idx, RR.Mask));
170   if (unsigned Idx = TRI.getSubRegIndex(RR.Reg, R)) {
171     const RegInfo &RI = RegInfos[R];
172     LaneBitmask RCM =
173         RI.RegClass ? RI.RegClass->LaneMask : LaneBitmask::getAll();
174     LaneBitmask M = TRI.reverseComposeSubRegIndexLaneMask(Idx, RR.Mask);
175     return RegisterRef(R, M & RCM);
176   }
177   llvm_unreachable("Invalid arguments: unrelated registers?");
178 }
179 
180 bool PhysicalRegisterInfo::equal_to(RegisterRef A, RegisterRef B) const {
181   if (!A.isReg() || !B.isReg()) {
182     // For non-regs, or comparing reg and non-reg, use only the Reg member.
183     return A.Reg == B.Reg;
184   }
185 
186   if (A.Reg == B.Reg)
187     return A.Mask == B.Mask;
188 
189   // Compare reg units lexicographically.
190   MCRegUnitMaskIterator AI(A.Reg, &getTRI());
191   MCRegUnitMaskIterator BI(B.Reg, &getTRI());
192   while (AI.isValid() && BI.isValid()) {
193     auto [AReg, AMask] = *AI;
194     auto [BReg, BMask] = *BI;
195 
196     // If both iterators point to a unit contained in both A and B, then
197     // compare the units.
198     if ((AMask & A.Mask).any() && (BMask & B.Mask).any()) {
199       if (AReg != BReg)
200         return false;
201       // Units are equal, move on to the next ones.
202       ++AI;
203       ++BI;
204       continue;
205     }
206 
207     if ((AMask & A.Mask).none())
208       ++AI;
209     if ((BMask & B.Mask).none())
210       ++BI;
211   }
212   // One or both have reached the end.
213   return static_cast<int>(AI.isValid()) == static_cast<int>(BI.isValid());
214 }
215 
216 bool PhysicalRegisterInfo::less(RegisterRef A, RegisterRef B) const {
217   if (!A.isReg() || !B.isReg()) {
218     // For non-regs, or comparing reg and non-reg, use only the Reg member.
219     return A.Reg < B.Reg;
220   }
221 
222   if (A.Reg == B.Reg)
223     return A.Mask < B.Mask;
224   if (A.Mask == B.Mask)
225     return A.Reg < B.Reg;
226 
227   // Compare reg units lexicographically.
228   llvm::MCRegUnitMaskIterator AI(A.Reg, &getTRI());
229   llvm::MCRegUnitMaskIterator BI(B.Reg, &getTRI());
230   while (AI.isValid() && BI.isValid()) {
231     auto [AReg, AMask] = *AI;
232     auto [BReg, BMask] = *BI;
233 
234     // If both iterators point to a unit contained in both A and B, then
235     // compare the units.
236     if ((AMask & A.Mask).any() && (BMask & B.Mask).any()) {
237       if (AReg != BReg)
238         return AReg < BReg;
239       // Units are equal, move on to the next ones.
240       ++AI;
241       ++BI;
242       continue;
243     }
244 
245     if ((AMask & A.Mask).none())
246       ++AI;
247     if ((BMask & B.Mask).none())
248       ++BI;
249   }
250   // One or both have reached the end: assume invalid < valid.
251   return static_cast<int>(AI.isValid()) < static_cast<int>(BI.isValid());
252 }
253 
254 void PhysicalRegisterInfo::print(raw_ostream &OS, RegisterRef A) const {
255   if (A.Reg == 0 || A.isReg()) {
256     if (0 < A.idx() && A.idx() < TRI.getNumRegs())
257       OS << TRI.getName(A.idx());
258     else
259       OS << printReg(A.idx(), &TRI);
260     OS << PrintLaneMaskShort(A.Mask);
261   } else if (A.isUnit()) {
262     OS << printRegUnit(A.idx(), &TRI);
263   } else {
264     assert(A.isMask());
265     // RegMask SS flag is preserved by idx().
266     unsigned Idx = Register::stackSlot2Index(A.idx());
267     const char *Fmt = Idx < 0x10000 ? "%04x" : "%08x";
268     OS << "M#" << format(Fmt, Idx);
269   }
270 }
271 
272 void PhysicalRegisterInfo::print(raw_ostream &OS, const RegisterAggr &A) const {
273   OS << '{';
274   for (unsigned U : A.units())
275     OS << ' ' << printRegUnit(U, &TRI);
276   OS << " }";
277 }
278 
279 bool RegisterAggr::hasAliasOf(RegisterRef RR) const {
280   if (RR.isMask())
281     return Units.anyCommon(PRI.getMaskUnits(RR.Reg));
282 
283   for (MCRegUnitMaskIterator U(RR.Reg, &PRI.getTRI()); U.isValid(); ++U) {
284     std::pair<uint32_t, LaneBitmask> P = *U;
285     if ((P.second & RR.Mask).any())
286       if (Units.test(P.first))
287         return true;
288   }
289   return false;
290 }
291 
292 bool RegisterAggr::hasCoverOf(RegisterRef RR) const {
293   if (RR.isMask()) {
294     BitVector T(PRI.getMaskUnits(RR.Reg));
295     return T.reset(Units).none();
296   }
297 
298   for (MCRegUnitMaskIterator U(RR.Reg, &PRI.getTRI()); U.isValid(); ++U) {
299     std::pair<uint32_t, LaneBitmask> P = *U;
300     if ((P.second & RR.Mask).any())
301       if (!Units.test(P.first))
302         return false;
303   }
304   return true;
305 }
306 
307 RegisterAggr &RegisterAggr::insert(RegisterRef RR) {
308   if (RR.isMask()) {
309     Units |= PRI.getMaskUnits(RR.Reg);
310     return *this;
311   }
312 
313   for (MCRegUnitMaskIterator U(RR.Reg, &PRI.getTRI()); U.isValid(); ++U) {
314     std::pair<uint32_t, LaneBitmask> P = *U;
315     if ((P.second & RR.Mask).any())
316       Units.set(P.first);
317   }
318   return *this;
319 }
320 
321 RegisterAggr &RegisterAggr::insert(const RegisterAggr &RG) {
322   Units |= RG.Units;
323   return *this;
324 }
325 
326 RegisterAggr &RegisterAggr::intersect(RegisterRef RR) {
327   return intersect(RegisterAggr(PRI).insert(RR));
328 }
329 
330 RegisterAggr &RegisterAggr::intersect(const RegisterAggr &RG) {
331   Units &= RG.Units;
332   return *this;
333 }
334 
335 RegisterAggr &RegisterAggr::clear(RegisterRef RR) {
336   return clear(RegisterAggr(PRI).insert(RR));
337 }
338 
339 RegisterAggr &RegisterAggr::clear(const RegisterAggr &RG) {
340   Units.reset(RG.Units);
341   return *this;
342 }
343 
344 RegisterRef RegisterAggr::intersectWith(RegisterRef RR) const {
345   RegisterAggr T(PRI);
346   T.insert(RR).intersect(*this);
347   if (T.empty())
348     return RegisterRef();
349   RegisterRef NR = T.makeRegRef();
350   assert(NR);
351   return NR;
352 }
353 
354 RegisterRef RegisterAggr::clearIn(RegisterRef RR) const {
355   return RegisterAggr(PRI).insert(RR).clear(*this).makeRegRef();
356 }
357 
358 RegisterRef RegisterAggr::makeRegRef() const {
359   int U = Units.find_first();
360   if (U < 0)
361     return RegisterRef();
362 
363   // Find the set of all registers that are aliased to all the units
364   // in this aggregate.
365 
366   // Get all the registers aliased to the first unit in the bit vector.
367   BitVector Regs = PRI.getUnitAliases(U);
368   U = Units.find_next(U);
369 
370   // For each other unit, intersect it with the set of all registers
371   // aliased that unit.
372   while (U >= 0) {
373     Regs &= PRI.getUnitAliases(U);
374     U = Units.find_next(U);
375   }
376 
377   // If there is at least one register remaining, pick the first one,
378   // and consolidate the masks of all of its units contained in this
379   // aggregate.
380 
381   int F = Regs.find_first();
382   if (F <= 0)
383     return RegisterRef();
384 
385   LaneBitmask M;
386   for (MCRegUnitMaskIterator I(F, &PRI.getTRI()); I.isValid(); ++I) {
387     std::pair<uint32_t, LaneBitmask> P = *I;
388     if (Units.test(P.first))
389       M |= P.second;
390   }
391   return RegisterRef(F, M);
392 }
393 
394 RegisterAggr::ref_iterator::ref_iterator(const RegisterAggr &RG, bool End)
395     : Owner(&RG) {
396   for (int U = RG.Units.find_first(); U >= 0; U = RG.Units.find_next(U)) {
397     RegisterRef R = RG.PRI.getRefForUnit(U);
398     Masks[R.Reg] |= R.Mask;
399   }
400   Pos = End ? Masks.end() : Masks.begin();
401   Index = End ? Masks.size() : 0;
402 }
403 
404 raw_ostream &operator<<(raw_ostream &OS, const RegisterAggr &A) {
405   A.getPRI().print(OS, A);
406   return OS;
407 }
408 
409 raw_ostream &operator<<(raw_ostream &OS, const PrintLaneMaskShort &P) {
410   if (P.Mask.all())
411     return OS;
412   if (P.Mask.none())
413     return OS << ":*none*";
414 
415   LaneBitmask::Type Val = P.Mask.getAsInteger();
416   if ((Val & 0xffff) == Val)
417     return OS << ':' << format("%04llX", Val);
418   if ((Val & 0xffffffff) == Val)
419     return OS << ':' << format("%08llX", Val);
420   return OS << ':' << PrintLaneMask(P.Mask);
421 }
422 
423 } // namespace llvm::rdf
424