1 //===-- X86TileConfig.cpp - Tile Register Configure----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
/// \file Pass to configure the shape of AMX physical registers.
/// AMX registers need to be configured before use. In the X86PreTileConfig
/// pass the pldtilecfg instruction is inserted; however, at that time we don't
/// know the shape of each physical tile register, because register
/// allocation is not done yet. This pass runs after the register allocation
/// pass. It collects the shape information of each physical tile register
/// and stores the shape in the stack slot that is allocated for loading the
/// configuration into the tile config register.
17 //
18 //===----------------------------------------------------------------------===//
19 
20 #include "X86.h"
21 #include "X86InstrBuilder.h"
22 #include "X86MachineFunctionInfo.h"
23 #include "X86RegisterInfo.h"
24 #include "X86Subtarget.h"
25 #include "llvm/CodeGen/LiveIntervals.h"
26 #include "llvm/CodeGen/MachineDominators.h"
27 #include "llvm/CodeGen/MachineFrameInfo.h"
28 #include "llvm/CodeGen/MachineFunctionPass.h"
29 #include "llvm/CodeGen/MachineInstr.h"
30 #include "llvm/CodeGen/MachineRegisterInfo.h"
31 #include "llvm/CodeGen/Passes.h"
32 #include "llvm/CodeGen/TargetInstrInfo.h"
33 #include "llvm/CodeGen/TargetRegisterInfo.h"
34 #include "llvm/CodeGen/TileShapeInfo.h"
35 #include "llvm/CodeGen/VirtRegMap.h"
36 #include "llvm/InitializePasses.h"
37 
38 using namespace llvm;
39 
40 #define DEBUG_TYPE "tile-config"
41 
42 namespace {
43 
44 class X86TileConfig : public MachineFunctionPass {
45   // context
46   MachineFunction *MF = nullptr;
47   const X86Subtarget *ST = nullptr;
48   const TargetRegisterInfo *TRI;
49   const TargetInstrInfo *TII;
50   MachineDominatorTree *DomTree = nullptr;
51   MachineRegisterInfo *MRI = nullptr;
52   VirtRegMap *VRM = nullptr;
53   LiveIntervals *LIS = nullptr;
54 
55   MachineInstr *getTileConfigPoint();
56   void tileConfig();
57 
58 public:
59   X86TileConfig() : MachineFunctionPass(ID) {}
60 
61   /// Return the pass name.
62   StringRef getPassName() const override { return "Tile Register Configure"; }
63 
64   /// X86TileConfig analysis usage.
65   void getAnalysisUsage(AnalysisUsage &AU) const override;
66 
67   /// Perform register allocation.
68   bool runOnMachineFunction(MachineFunction &mf) override;
69 
70   MachineFunctionProperties getRequiredProperties() const override {
71     return MachineFunctionProperties().set(
72         MachineFunctionProperties::Property::NoPHIs);
73   }
74 
75   static char ID;
76 };
77 
78 } // end anonymous namespace
79 
80 char X86TileConfig::ID = 0;
81 
82 INITIALIZE_PASS_BEGIN(X86TileConfig, "tileconfig", "Tile Register Configure",
83                       false, false)
84 INITIALIZE_PASS_DEPENDENCY(MachineDominatorTree)
85 INITIALIZE_PASS_DEPENDENCY(VirtRegMap)
86 INITIALIZE_PASS_END(X86TileConfig, "tileconfig", "Tile Register Configure",
87                     false, false)
88 
89 void X86TileConfig::getAnalysisUsage(AnalysisUsage &AU) const {
90   AU.addRequired<MachineDominatorTree>();
91   AU.addRequired<LiveIntervals>();
92   AU.addPreserved<SlotIndexes>();
93   AU.addRequired<VirtRegMap>();
94   AU.setPreservesAll();
95   MachineFunctionPass::getAnalysisUsage(AU);
96 }
97 
98 static unsigned getTilePhysRegIndex(Register PhysReg) {
99   assert((PhysReg >= X86::TMM0 && X86::TMM0 <= X86::TMM7) &&
100          "Tile register number is invalid");
101   return (PhysReg - X86::TMM0);
102 }
103 
104 static MachineInstr *
105 storeRegToStackSlot(MachineBasicBlock &MBB, MachineBasicBlock::iterator MI,
106                     Register SrcReg, unsigned BitSize, int FrameIdx, int Offset,
107                     const TargetInstrInfo *TII, const TargetRegisterClass *RC,
108                     const TargetRegisterInfo *TRI) {
109 
110   unsigned SubIdx = (BitSize == 8) ? X86::sub_8bit : X86::sub_16bit;
111   unsigned Opc = (BitSize == 8) ? X86::MOV8mr : X86::MOV16mr;
112   if (BitSize == TRI->getRegSizeInBits(*RC))
113     SubIdx = 0;
114   MachineInstr *NewMI =
115       addFrameReference(BuildMI(MBB, MI, DebugLoc(), TII->get(Opc)), FrameIdx,
116                         Offset)
117           .addReg(SrcReg, 0, SubIdx);
118   return NewMI;
119 }
120 
121 static MachineInstr *storeImmToStackSlot(MachineBasicBlock &MBB,
122                                          MachineBasicBlock::iterator MI,
123                                          int64_t Imm, unsigned BitSize,
124                                          int FrameIdx, int Offset,
125                                          const TargetInstrInfo *TII) {
126   unsigned Opc = (BitSize == 8) ? X86::MOV8mi : X86::MOV16mi;
127   return addFrameReference(BuildMI(MBB, MI, DebugLoc(), TII->get(Opc)),
128                            FrameIdx, Offset)
129       .addImm(Imm);
130 }
131 
132 MachineInstr *X86TileConfig::getTileConfigPoint() {
133   for (MachineBasicBlock &MBB : *MF) {
134 
135     // Traverse the basic block.
136     for (MachineInstr &MI : MBB)
137       // Refer X86PreTileConfig.cpp.
138       // We only support one tile config for now.
139       if (MI.getOpcode() == X86::PLDTILECFG)
140         return &MI;
141   }
142 
143   return nullptr;
144 }
145 
146 void X86TileConfig::tileConfig() {
147   MachineInstr *MI = getTileConfigPoint();
148   if (!MI)
149     return;
150   MachineBasicBlock *MBB = MI->getParent();
151   int SS = MI->getOperand(1).getIndex();
152   BitVector PhysRegs(TRI->getNumRegs());
153 
154   // Fill in the palette first.
155   auto *NewMI = storeImmToStackSlot(*MBB, *MI, 1, 8, SS, 0, TII);
156   LIS->InsertMachineInstrInMaps(*NewMI);
157   // Fill in the shape of each tile physical register.
158   for (unsigned i = 0, e = MRI->getNumVirtRegs(); i != e; ++i) {
159     Register VirtReg = Register::index2VirtReg(i);
160     if (MRI->reg_nodbg_empty(VirtReg))
161       continue;
162     const TargetRegisterClass &RC = *MRI->getRegClass(VirtReg);
163     if (RC.getID() != X86::TILERegClassID)
164       continue;
165     Register PhysReg = VRM->getPhys(VirtReg);
166     if (PhysRegs.test(PhysReg))
167       continue;
168     PhysRegs.set(PhysReg);
169     ShapeT Shape = VRM->getShape(VirtReg);
170     Register RowReg = Shape.getRow()->getReg();
171     Register ColReg = Shape.getCol()->getReg();
172 
173     // Here is the data format for the tile config.
174     // 0      palette
175     // 1      start_row
176     // 2-15   reserved, must be zero
177     // 16-17  tile0.colsb Tile 0 bytes per row.
178     // 18-19  tile1.colsb Tile 1 bytes per row.
179     // 20-21  tile2.colsb Tile 2 bytes per row.
180     // ... (sequence continues)
181     // 30-31  tile7.colsb Tile 7 bytes per row.
182     // 32-47  reserved, must be zero
183     // 48     tile0.rows Tile 0 rows.
184     // 49     tile1.rows Tile 1 rows.
185     // 50     tile2.rows Tile 2 rows.
186     // ... (sequence continues)
187     // 55     tile7.rows Tile 7 rows.
188     // 56-63  reserved, must be zero
189     unsigned Index = getTilePhysRegIndex(PhysReg);
190     int RowOffset = 48 + Index;
191     int ColOffset = 16 + Index * 2;
192 
193     unsigned BitSize = 8;
194     for (const auto &Pair : {std::make_pair(RowReg, RowOffset),
195                              std::make_pair(ColReg, ColOffset)}) {
196       int64_t Imm;
197       int ImmCount = 0;
198       // All def must be the same value, otherwise it is invalid MIs.
199       // Immediate is prefered.
200       for (const MachineOperand &MO : MRI->def_operands(Pair.first)) {
201         const auto *Inst = MO.getParent();
202         if (Inst->isMoveImmediate()) {
203           ImmCount++;
204           Imm = Inst->getOperand(1).getImm();
205           break;
206         }
207       }
208       auto StoreConfig = [&](int Offset) {
209         MachineInstr *NewMI = nullptr;
210         if (ImmCount)
211           NewMI = storeImmToStackSlot(*MBB, *MI, Imm, BitSize, SS, Offset, TII);
212         else {
213           const TargetRegisterClass *RC = MRI->getRegClass(Pair.first);
214           NewMI = storeRegToStackSlot(*MBB, *MI, Pair.first, BitSize, SS,
215                                       Offset, TII, RC, TRI);
216         }
217         SlotIndex SIdx = LIS->InsertMachineInstrInMaps(*NewMI);
218         if (!ImmCount) {
219           // Extend the live interval.
220           SmallVector<SlotIndex, 8> EndPoints = {SIdx.getRegSlot()};
221           LiveInterval &Int = LIS->getInterval(Pair.first);
222           LIS->extendToIndices(Int, EndPoints);
223         }
224       };
225       StoreConfig(Pair.second);
226       BitSize += 8;
227     }
228   }
229 }
230 
231 bool X86TileConfig::runOnMachineFunction(MachineFunction &mf) {
232   MF = &mf;
233   MRI = &mf.getRegInfo();
234   ST = &mf.getSubtarget<X86Subtarget>();
235   TRI = ST->getRegisterInfo();
236   TII = mf.getSubtarget().getInstrInfo();
237   DomTree = &getAnalysis<MachineDominatorTree>();
238   VRM = &getAnalysis<VirtRegMap>();
239   LIS = &getAnalysis<LiveIntervals>();
240 
241   if (VRM->isShapeMapEmpty())
242     return false;
243 
244   tileConfig();
245   return true;
246 }
247 
248 FunctionPass *llvm::createX86TileConfigPass() { return new X86TileConfig(); }
249