1 //===-- X86PreTileConfig.cpp - Tile Register Configure---------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
/// \file Pass to pre-configure the shape of AMX registers.
/// AMX registers need to be configured before use. The shape of an AMX
/// register is encoded in the 1st and 2nd machine operands of the AMX pseudo
/// instructions. The pldtilecfg instruction configures the tile registers. It
/// should dominate all AMX instructions. The pldtilecfg produces a virtual cfg
/// register, and that cfg register is used by all AMX instructions.
/// This pass finds the common dominator of all AMX instructions and inserts
/// the pldtilecfg instruction there. In addition, the cfg register that
/// pldtilecfg produces is added as the last operand of each AMX instruction.
/// We use this scheme to model the def-use relationship between the AMX config
/// instruction and the other AMX instructions. Below is an example.
20 ///
21 ///                        ----B1----
22 ///                       /           \
23 ///                      /             \
24 ///                    B2               B3
25 ///    %1:tile = PTILELOADDV        %2:tile = PTILELOADDV
26 ///
27 ///  is transformed to
28 ///
29 ///                            B1
30 ///                 %25:tilecfg = PLDTILECFG
31 ///                       /           \
32 ///                      /             \
33 ///  %1:tile = PTILELOADDV %25    %2:tile = PTILELOADDV %25
34 //
35 //===----------------------------------------------------------------------===//
36 
37 #include "X86.h"
38 #include "X86InstrBuilder.h"
39 #include "X86RegisterInfo.h"
40 #include "X86Subtarget.h"
41 #include "llvm/CodeGen/MachineDominators.h"
42 #include "llvm/CodeGen/MachineFunctionPass.h"
43 #include "llvm/CodeGen/MachineInstr.h"
44 #include "llvm/CodeGen/MachineRegisterInfo.h"
45 #include "llvm/CodeGen/Passes.h"
46 #include "llvm/CodeGen/TargetInstrInfo.h"
47 #include "llvm/CodeGen/TargetRegisterInfo.h"
48 #include "llvm/CodeGen/TileShapeInfo.h"
49 #include "llvm/InitializePasses.h"
50 
51 using namespace llvm;
52 
53 #define DEBUG_TYPE "tile-pre-config"
54 
55 namespace {
56 
57 class X86PreTileConfig : public MachineFunctionPass {
58   // context
59   MachineFunction *MF = nullptr;
60   const X86Subtarget *ST = nullptr;
61   const TargetRegisterInfo *TRI;
62   const TargetInstrInfo *TII;
63   MachineDominatorTree *DomTree = nullptr;
64   MachineRegisterInfo *MRI = nullptr;
65 
66   MachineInstr *getTileConfigPoint();
67 
68 public:
69   X86PreTileConfig() : MachineFunctionPass(ID) {}
70 
71   /// Return the pass name.
72   StringRef getPassName() const override {
73     return "Tile Register Pre-configure";
74   }
75 
76   /// X86PreTileConfig analysis usage.
77   void getAnalysisUsage(AnalysisUsage &AU) const override;
78 
79   /// Perform register allocation.
80   bool runOnMachineFunction(MachineFunction &mf) override;
81 
82   static char ID;
83 };
84 
85 } // end anonymous namespace
86 
// Storage for the static pass ID that the constructor hands to
// MachineFunctionPass.
char X86PreTileConfig::ID = 0;

// Register the pass under the command-line name "tilepreconfig" and declare
// its dependency on the machine dominator tree analysis.
INITIALIZE_PASS_BEGIN(X86PreTileConfig, "tilepreconfig",
                      "Tile Register Configure", false, false)
INITIALIZE_PASS_DEPENDENCY(MachineDominatorTree)
INITIALIZE_PASS_END(X86PreTileConfig, "tilepreconfig",
                    "Tile Register Configure", false, false)
94 
95 void X86PreTileConfig::getAnalysisUsage(AnalysisUsage &AU) const {
96   AU.setPreservesAll();
97   AU.addRequired<MachineDominatorTree>();
98   MachineFunctionPass::getAnalysisUsage(AU);
99 }
100 
101 static Register buildConfigMI(MachineBasicBlock::iterator MI, int FrameIdx,
102                               const TargetInstrInfo *TII,
103                               MachineRegisterInfo *MRI,
104                               const X86Subtarget *ST) {
105   auto *MBB = MI->getParent();
106 
107   // FIXME: AMX should assume AVX512 enabled.
108   if (ST->hasAVX512()) {
109     // Zero stack slot.
110     Register Zmm = MRI->createVirtualRegister(&X86::VR512RegClass);
111     BuildMI(*MBB, MI, DebugLoc(), TII->get(X86::VPXORDZrr), Zmm)
112         .addReg(Zmm, RegState::Undef)
113         .addReg(Zmm, RegState::Undef);
114     addFrameReference(BuildMI(*MBB, MI, DebugLoc(), TII->get(X86::VMOVUPSZmr)),
115                       FrameIdx)
116         .addReg(Zmm);
117   }
118 
119   // build psuedo ldtilecfg
120   Register VReg = MRI->createVirtualRegister(&X86::TILECFGRegClass);
121 
122   addFrameReference(
123       BuildMI(*MBB, MI, DebugLoc(), TII->get(X86::PLDTILECFG), VReg), FrameIdx);
124 
125   return VReg;
126 }
127 
128 static ShapeT getShape(const MachineInstr &MI, MachineRegisterInfo *MRI) {
129   unsigned Opcode = MI.getOpcode();
130   switch (Opcode) {
131   default:
132     llvm_unreachable("Unexpected machine instruction on tile");
133   case X86::PTILELOADDV:
134   case X86::PTDPBSSDV:
135   case X86::PTILEZEROV:
136     MachineOperand &MO1 = const_cast<MachineOperand &>(MI.getOperand(1));
137     MachineOperand &MO2 = const_cast<MachineOperand &>(MI.getOperand(2));
138     ShapeT Shape(&MO1, &MO2, MRI);
139     return Shape;
140   }
141 }
142 
143 MachineInstr *X86PreTileConfig::getTileConfigPoint() {
144   DenseMap<Register, ShapeT> PhysShapeInfo;
145   MachineBasicBlock *MBB = nullptr;
146   DenseSet<const MachineInstr *> MIs;
147   for (unsigned i = 0, e = MRI->getNumVirtRegs(); i != e; ++i) {
148     Register VirtReg = Register::index2VirtReg(i);
149     if (MRI->reg_nodbg_empty(VirtReg))
150       continue;
151     const TargetRegisterClass &RC = *MRI->getRegClass(VirtReg);
152     if (RC.getID() != X86::TILERegClassID)
153       continue;
154 
155     // Find the common dominator for all MI that define tile register.
156     for (const MachineOperand &MO : MRI->def_operands(VirtReg)) {
157       if (MO.isUndef())
158         continue;
159       const auto *MI = MO.getParent();
160       // PHI or IMPLICIT_DEF instructiion.
161       // There must be a input tile before PHI instruction.
162       if (MI->isTransient())
163         continue;
164       if (!MBB)
165         MBB = const_cast<MachineBasicBlock *>(MI->getParent());
166       MBB = DomTree->findNearestCommonDominator(
167           MBB, const_cast<MachineBasicBlock *>(MI->getParent()));
168 
169       // Collect the instructions that define shape.
170       ShapeT Shape = getShape(*MI, MRI);
171       std::array<MachineOperand *, 2> ShapeMOs = {Shape.getRow(),
172                                                   Shape.getCol()};
173       for (auto *ShapeMO : ShapeMOs) {
174         Register ShapeReg = ShapeMO->getReg();
175         for (const MachineOperand &MO : MRI->def_operands(ShapeReg)) {
176           const auto *ShapeMI = MO.getParent();
177           MIs.insert(ShapeMI);
178         }
179       }
180     }
181   }
182   if (!MBB)
183     return nullptr;
184   // This pass is before the pass of eliminating PHI node, so it
185   // is in SSA form.
186   assert(MRI->isSSA() && "Not SSA form in pre-tile config");
187   // Shape def should dominate tile config MBB.
188   //    def s           s1    s2
189   //     / \             \   /
190   //    /   \             \ /
191   //  conf               s3=phi(s1,s2)
192   //                       |
193   //                       c
194   //
195   for (const auto *MI : MIs) {
196     const MachineBasicBlock *ShapeMBB = MI->getParent();
197     if (DomTree->dominates(ShapeMBB, MBB))
198       continue;
199     if (MI->isMoveImmediate())
200       continue;
201     report_fatal_error(MF->getName() + ": Failed to config tile register, "
202                                        "please define the shape earlier");
203   }
204 
205   // ldtilecfg should be inserted after the MI that define the shape.
206   MachineBasicBlock::reverse_instr_iterator I, E;
207   for (I = MBB->instr_rbegin(), E = MBB->instr_rend(); I != E; ++I) {
208     auto *MI = &*I;
209     if (MIs.count(MI) && (!MI->isMoveImmediate()))
210       break;
211   }
212   MachineBasicBlock::iterator MII;
213   if (I == E)
214     MII = MBB->getFirstNonPHI();
215   else {
216     MII = MachineBasicBlock::iterator(&*I);
217     MII++;
218   }
219   return &*MII;
220 }
221 
222 static void addTileCFGUse(MachineFunction &MF, Register CFG) {
223   for (MachineBasicBlock &MBB : MF) {
224 
225     // Traverse the basic block.
226     for (MachineInstr &MI : MBB) {
227       unsigned Opcode = MI.getOpcode();
228       switch (Opcode) {
229       default:
230         break;
231       case X86::PTILELOADDV:
232       case X86::PTILESTOREDV:
233       case X86::PTDPBSSDV:
234       case X86::PTILEZEROV:
235         unsigned NumOperands = MI.getNumOperands();
236         MI.RemoveOperand(NumOperands - 1);
237         MI.addOperand(MF, MachineOperand::CreateReg(CFG, false));
238         break;
239       }
240     }
241   }
242 }
243 
244 bool X86PreTileConfig::runOnMachineFunction(MachineFunction &mf) {
245   MF = &mf;
246   MRI = &mf.getRegInfo();
247   ST = &mf.getSubtarget<X86Subtarget>();
248   TRI = ST->getRegisterInfo();
249   TII = mf.getSubtarget().getInstrInfo();
250   DomTree = &getAnalysis<MachineDominatorTree>();
251 
252   MachineInstr *MI = getTileConfigPoint();
253   if (!MI)
254     return false;
255   unsigned Size = ST->getTileConfigSize();
256   Align Alignment = ST->getTileConfigAlignment();
257   int SS = mf.getFrameInfo().CreateStackObject(Size, Alignment, false);
258   Register CFG = buildConfigMI(MI, SS, TII, MRI, ST);
259   addTileCFGUse(mf, CFG);
260   return true;
261 }
262 
263 FunctionPass *llvm::createX86PreTileConfigPass() {
264   return new X86PreTileConfig();
265 }
266