1 //===- RDFRegisters.cpp ---------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "llvm/ADT/BitVector.h"
10 #include "llvm/CodeGen/MachineFunction.h"
11 #include "llvm/CodeGen/MachineInstr.h"
12 #include "llvm/CodeGen/MachineOperand.h"
13 #include "llvm/CodeGen/RDFRegisters.h"
14 #include "llvm/CodeGen/TargetRegisterInfo.h"
15 #include "llvm/MC/LaneBitmask.h"
16 #include "llvm/MC/MCRegisterInfo.h"
17 #include "llvm/Support/ErrorHandling.h"
18 #include "llvm/Support/Format.h"
19 #include "llvm/Support/MathExtras.h"
20 #include "llvm/Support/raw_ostream.h"
21 #include <cassert>
22 #include <cstdint>
23 #include <set>
24 #include <utility>
25 
26 namespace llvm::rdf {
27 
28 PhysicalRegisterInfo::PhysicalRegisterInfo(const TargetRegisterInfo &tri,
29                                            const MachineFunction &mf)
30     : TRI(tri) {
31   RegInfos.resize(TRI.getNumRegs());
32 
33   BitVector BadRC(TRI.getNumRegs());
34   for (const TargetRegisterClass *RC : TRI.regclasses()) {
35     for (MCPhysReg R : *RC) {
36       RegInfo &RI = RegInfos[R];
37       if (RI.RegClass != nullptr && !BadRC[R]) {
38         if (RC->LaneMask != RI.RegClass->LaneMask) {
39           BadRC.set(R);
40           RI.RegClass = nullptr;
41         }
42       } else
43         RI.RegClass = RC;
44     }
45   }
46 
47   UnitInfos.resize(TRI.getNumRegUnits());
48 
49   for (uint32_t U = 0, NU = TRI.getNumRegUnits(); U != NU; ++U) {
50     if (UnitInfos[U].Reg != 0)
51       continue;
52     MCRegUnitRootIterator R(U, &TRI);
53     assert(R.isValid());
54     RegisterId F = *R;
55     ++R;
56     if (R.isValid()) {
57       UnitInfos[U].Mask = LaneBitmask::getAll();
58       UnitInfos[U].Reg = F;
59     } else {
60       for (MCRegUnitMaskIterator I(F, &TRI); I.isValid(); ++I) {
61         std::pair<uint32_t, LaneBitmask> P = *I;
62         UnitInfo &UI = UnitInfos[P.first];
63         UI.Reg = F;
64         if (P.second.any()) {
65           UI.Mask = P.second;
66         } else {
67           if (const TargetRegisterClass *RC = RegInfos[F].RegClass)
68             UI.Mask = RC->LaneMask;
69           else
70             UI.Mask = LaneBitmask::getAll();
71         }
72       }
73     }
74   }
75 
76   for (const uint32_t *RM : TRI.getRegMasks())
77     RegMasks.insert(RM);
78   for (const MachineBasicBlock &B : mf)
79     for (const MachineInstr &In : B)
80       for (const MachineOperand &Op : In.operands())
81         if (Op.isRegMask())
82           RegMasks.insert(Op.getRegMask());
83 
84   MaskInfos.resize(RegMasks.size() + 1);
85   for (uint32_t M = 1, NM = RegMasks.size(); M <= NM; ++M) {
86     BitVector PU(TRI.getNumRegUnits());
87     const uint32_t *MB = RegMasks.get(M);
88     for (unsigned I = 1, E = TRI.getNumRegs(); I != E; ++I) {
89       if (!(MB[I / 32] & (1u << (I % 32))))
90         continue;
91       for (MCRegUnit Unit : TRI.regunits(MCRegister::from(I)))
92         PU.set(Unit);
93     }
94     MaskInfos[M].Units = PU.flip();
95   }
96 
97   AliasInfos.resize(TRI.getNumRegUnits());
98   for (uint32_t U = 0, NU = TRI.getNumRegUnits(); U != NU; ++U) {
99     BitVector AS(TRI.getNumRegs());
100     for (MCRegUnitRootIterator R(U, &TRI); R.isValid(); ++R)
101       for (MCPhysReg S : TRI.superregs_inclusive(*R))
102         AS.set(S);
103     AliasInfos[U].Regs = AS;
104   }
105 }
106 
107 bool PhysicalRegisterInfo::alias(RegisterRef RA, RegisterRef RB) const {
108   return !disjoint(getUnits(RA), getUnits(RB));
109 }
110 
111 std::set<RegisterId> PhysicalRegisterInfo::getAliasSet(RegisterId Reg) const {
112   // Do not include Reg in the alias set.
113   std::set<RegisterId> AS;
114   assert(!RegisterRef::isUnitId(Reg) && "No units allowed");
115   if (RegisterRef::isMaskId(Reg)) {
116     // XXX SLOW
117     const uint32_t *MB = getRegMaskBits(Reg);
118     for (unsigned i = 1, e = TRI.getNumRegs(); i != e; ++i) {
119       if (MB[i / 32] & (1u << (i % 32)))
120         continue;
121       AS.insert(i);
122     }
123     return AS;
124   }
125 
126   assert(RegisterRef::isRegId(Reg));
127   for (MCRegAliasIterator AI(Reg, &TRI, false); AI.isValid(); ++AI)
128     AS.insert(*AI);
129 
130   return AS;
131 }
132 
133 std::set<RegisterId> PhysicalRegisterInfo::getUnits(RegisterRef RR) const {
134   std::set<RegisterId> Units;
135 
136   if (RR.Reg == 0)
137     return Units; // Empty
138 
139   if (RR.isReg()) {
140     if (RR.Mask.none())
141       return Units; // Empty
142     for (MCRegUnitMaskIterator UM(RR.idx(), &TRI); UM.isValid(); ++UM) {
143       auto [U, M] = *UM;
144       if (M.none() || (M & RR.Mask).any())
145         Units.insert(U);
146     }
147     return Units;
148   }
149 
150   assert(RR.isMask());
151   unsigned NumRegs = TRI.getNumRegs();
152   const uint32_t *MB = getRegMaskBits(RR.idx());
153   for (unsigned I = 0, E = (NumRegs + 31) / 32; I != E; ++I) {
154     uint32_t C = ~MB[I]; // Clobbered regs
155     if (I == 0)          // Reg 0 should be ignored
156       C &= maskLeadingOnes<unsigned>(31);
157     if (I + 1 == E && NumRegs % 32 != 0) // Last word may be partial
158       C &= maskTrailingOnes<unsigned>(NumRegs % 32);
159     if (C == 0)
160       continue;
161     while (C != 0) {
162       unsigned T = llvm::countr_zero(C);
163       unsigned CR = 32 * I + T; // Clobbered reg
164       for (MCRegUnit U : TRI.regunits(CR))
165         Units.insert(U);
166       C &= ~(1u << T);
167     }
168   }
169   return Units;
170 }
171 
172 RegisterRef PhysicalRegisterInfo::mapTo(RegisterRef RR, unsigned R) const {
173   if (RR.Reg == R)
174     return RR;
175   if (unsigned Idx = TRI.getSubRegIndex(R, RR.Reg))
176     return RegisterRef(R, TRI.composeSubRegIndexLaneMask(Idx, RR.Mask));
177   if (unsigned Idx = TRI.getSubRegIndex(RR.Reg, R)) {
178     const RegInfo &RI = RegInfos[R];
179     LaneBitmask RCM =
180         RI.RegClass ? RI.RegClass->LaneMask : LaneBitmask::getAll();
181     LaneBitmask M = TRI.reverseComposeSubRegIndexLaneMask(Idx, RR.Mask);
182     return RegisterRef(R, M & RCM);
183   }
184   llvm_unreachable("Invalid arguments: unrelated registers?");
185 }
186 
187 bool PhysicalRegisterInfo::equal_to(RegisterRef A, RegisterRef B) const {
188   if (!A.isReg() || !B.isReg()) {
189     // For non-regs, or comparing reg and non-reg, use only the Reg member.
190     return A.Reg == B.Reg;
191   }
192 
193   if (A.Reg == B.Reg)
194     return A.Mask == B.Mask;
195 
196   // Compare reg units lexicographically.
197   MCRegUnitMaskIterator AI(A.Reg, &getTRI());
198   MCRegUnitMaskIterator BI(B.Reg, &getTRI());
199   while (AI.isValid() && BI.isValid()) {
200     auto [AReg, AMask] = *AI;
201     auto [BReg, BMask] = *BI;
202 
203     // Lane masks are "none" for units that don't correspond to subregs
204     // e.g. a single unit in a leaf register, or aliased unit.
205     if (AMask.none())
206       AMask = LaneBitmask::getAll();
207     if (BMask.none())
208       BMask = LaneBitmask::getAll();
209 
210     // If both iterators point to a unit contained in both A and B, then
211     // compare the units.
212     if ((AMask & A.Mask).any() && (BMask & B.Mask).any()) {
213       if (AReg != BReg)
214         return false;
215       // Units are equal, move on to the next ones.
216       ++AI;
217       ++BI;
218       continue;
219     }
220 
221     if ((AMask & A.Mask).none())
222       ++AI;
223     if ((BMask & B.Mask).none())
224       ++BI;
225   }
226   // One or both have reached the end.
227   return static_cast<int>(AI.isValid()) == static_cast<int>(BI.isValid());
228 }
229 
230 bool PhysicalRegisterInfo::less(RegisterRef A, RegisterRef B) const {
231   if (!A.isReg() || !B.isReg()) {
232     // For non-regs, or comparing reg and non-reg, use only the Reg member.
233     return A.Reg < B.Reg;
234   }
235 
236   if (A.Reg == B.Reg)
237     return A.Mask < B.Mask;
238   if (A.Mask == B.Mask)
239     return A.Reg < B.Reg;
240 
241   // Compare reg units lexicographically.
242   llvm::MCRegUnitMaskIterator AI(A.Reg, &getTRI());
243   llvm::MCRegUnitMaskIterator BI(B.Reg, &getTRI());
244   while (AI.isValid() && BI.isValid()) {
245     auto [AReg, AMask] = *AI;
246     auto [BReg, BMask] = *BI;
247 
248     // Lane masks are "none" for units that don't correspond to subregs
249     // e.g. a single unit in a leaf register, or aliased unit.
250     if (AMask.none())
251       AMask = LaneBitmask::getAll();
252     if (BMask.none())
253       BMask = LaneBitmask::getAll();
254 
255     // If both iterators point to a unit contained in both A and B, then
256     // compare the units.
257     if ((AMask & A.Mask).any() && (BMask & B.Mask).any()) {
258       if (AReg != BReg)
259         return AReg < BReg;
260       // Units are equal, move on to the next ones.
261       ++AI;
262       ++BI;
263       continue;
264     }
265 
266     if ((AMask & A.Mask).none())
267       ++AI;
268     if ((BMask & B.Mask).none())
269       ++BI;
270   }
271   // One or both have reached the end: assume invalid < valid.
272   return static_cast<int>(AI.isValid()) < static_cast<int>(BI.isValid());
273 }
274 
275 void PhysicalRegisterInfo::print(raw_ostream &OS, RegisterRef A) const {
276   if (A.Reg == 0 || A.isReg()) {
277     if (0 < A.idx() && A.idx() < TRI.getNumRegs())
278       OS << TRI.getName(A.idx());
279     else
280       OS << printReg(A.idx(), &TRI);
281     OS << PrintLaneMaskShort(A.Mask);
282   } else if (A.isUnit()) {
283     OS << printRegUnit(A.idx(), &TRI);
284   } else {
285     assert(A.isMask());
286     // RegMask SS flag is preserved by idx().
287     unsigned Idx = Register::stackSlot2Index(A.idx());
288     const char *Fmt = Idx < 0x10000 ? "%04x" : "%08x";
289     OS << "M#" << format(Fmt, Idx);
290   }
291 }
292 
293 void PhysicalRegisterInfo::print(raw_ostream &OS, const RegisterAggr &A) const {
294   OS << '{';
295   for (unsigned U : A.units())
296     OS << ' ' << printRegUnit(U, &TRI);
297   OS << " }";
298 }
299 
300 bool RegisterAggr::hasAliasOf(RegisterRef RR) const {
301   if (RR.isMask())
302     return Units.anyCommon(PRI.getMaskUnits(RR.Reg));
303 
304   for (MCRegUnitMaskIterator U(RR.Reg, &PRI.getTRI()); U.isValid(); ++U) {
305     std::pair<uint32_t, LaneBitmask> P = *U;
306     if (P.second.none() || (P.second & RR.Mask).any())
307       if (Units.test(P.first))
308         return true;
309   }
310   return false;
311 }
312 
313 bool RegisterAggr::hasCoverOf(RegisterRef RR) const {
314   if (RR.isMask()) {
315     BitVector T(PRI.getMaskUnits(RR.Reg));
316     return T.reset(Units).none();
317   }
318 
319   for (MCRegUnitMaskIterator U(RR.Reg, &PRI.getTRI()); U.isValid(); ++U) {
320     std::pair<uint32_t, LaneBitmask> P = *U;
321     if (P.second.none() || (P.second & RR.Mask).any())
322       if (!Units.test(P.first))
323         return false;
324   }
325   return true;
326 }
327 
328 RegisterAggr &RegisterAggr::insert(RegisterRef RR) {
329   if (RR.isMask()) {
330     Units |= PRI.getMaskUnits(RR.Reg);
331     return *this;
332   }
333 
334   for (MCRegUnitMaskIterator U(RR.Reg, &PRI.getTRI()); U.isValid(); ++U) {
335     std::pair<uint32_t, LaneBitmask> P = *U;
336     if (P.second.none() || (P.second & RR.Mask).any())
337       Units.set(P.first);
338   }
339   return *this;
340 }
341 
342 RegisterAggr &RegisterAggr::insert(const RegisterAggr &RG) {
343   Units |= RG.Units;
344   return *this;
345 }
346 
347 RegisterAggr &RegisterAggr::intersect(RegisterRef RR) {
348   return intersect(RegisterAggr(PRI).insert(RR));
349 }
350 
351 RegisterAggr &RegisterAggr::intersect(const RegisterAggr &RG) {
352   Units &= RG.Units;
353   return *this;
354 }
355 
356 RegisterAggr &RegisterAggr::clear(RegisterRef RR) {
357   return clear(RegisterAggr(PRI).insert(RR));
358 }
359 
360 RegisterAggr &RegisterAggr::clear(const RegisterAggr &RG) {
361   Units.reset(RG.Units);
362   return *this;
363 }
364 
365 RegisterRef RegisterAggr::intersectWith(RegisterRef RR) const {
366   RegisterAggr T(PRI);
367   T.insert(RR).intersect(*this);
368   if (T.empty())
369     return RegisterRef();
370   RegisterRef NR = T.makeRegRef();
371   assert(NR);
372   return NR;
373 }
374 
375 RegisterRef RegisterAggr::clearIn(RegisterRef RR) const {
376   return RegisterAggr(PRI).insert(RR).clear(*this).makeRegRef();
377 }
378 
379 RegisterRef RegisterAggr::makeRegRef() const {
380   int U = Units.find_first();
381   if (U < 0)
382     return RegisterRef();
383 
384   // Find the set of all registers that are aliased to all the units
385   // in this aggregate.
386 
387   // Get all the registers aliased to the first unit in the bit vector.
388   BitVector Regs = PRI.getUnitAliases(U);
389   U = Units.find_next(U);
390 
391   // For each other unit, intersect it with the set of all registers
392   // aliased that unit.
393   while (U >= 0) {
394     Regs &= PRI.getUnitAliases(U);
395     U = Units.find_next(U);
396   }
397 
398   // If there is at least one register remaining, pick the first one,
399   // and consolidate the masks of all of its units contained in this
400   // aggregate.
401 
402   int F = Regs.find_first();
403   if (F <= 0)
404     return RegisterRef();
405 
406   LaneBitmask M;
407   for (MCRegUnitMaskIterator I(F, &PRI.getTRI()); I.isValid(); ++I) {
408     std::pair<uint32_t, LaneBitmask> P = *I;
409     if (Units.test(P.first))
410       M |= P.second.none() ? LaneBitmask::getAll() : P.second;
411   }
412   return RegisterRef(F, M);
413 }
414 
415 RegisterAggr::ref_iterator::ref_iterator(const RegisterAggr &RG, bool End)
416     : Owner(&RG) {
417   for (int U = RG.Units.find_first(); U >= 0; U = RG.Units.find_next(U)) {
418     RegisterRef R = RG.PRI.getRefForUnit(U);
419     Masks[R.Reg] |= R.Mask;
420   }
421   Pos = End ? Masks.end() : Masks.begin();
422   Index = End ? Masks.size() : 0;
423 }
424 
425 raw_ostream &operator<<(raw_ostream &OS, const RegisterAggr &A) {
426   A.getPRI().print(OS, A);
427   return OS;
428 }
429 
430 raw_ostream &operator<<(raw_ostream &OS, const PrintLaneMaskShort &P) {
431   if (P.Mask.all())
432     return OS;
433   if (P.Mask.none())
434     return OS << ":*none*";
435 
436   LaneBitmask::Type Val = P.Mask.getAsInteger();
437   if ((Val & 0xffff) == Val)
438     return OS << ':' << format("%04llX", Val);
439   if ((Val & 0xffffffff) == Val)
440     return OS << ':' << format("%08llX", Val);
441   return OS << ':' << PrintLaneMask(P.Mask);
442 }
443 
444 } // namespace llvm::rdf
445